@@ -389,6 +389,141 @@ GaussianSplat3d::projectGaussiansImpl(const torch::Tensor &worldToCameraMatrices
389389 return ret;
390390}
391391
392+ GaussianSplat3d::SparseProjectedGaussianSplats
393+ GaussianSplat3d::sparseProjectGaussiansImpl (const JaggedTensor &pixelsToRender,
394+ const torch::Tensor &worldToCameraMatrices,
395+ const torch::Tensor &projectionMatrices,
396+ const RenderSettings &settings) {
397+ FVDB_FUNC_RANGE ();
398+ const bool ortho = settings.projectionType == fvdb::detail::ops::ProjectionType::ORTHOGRAPHIC;
399+ const int C = worldToCameraMatrices.size (0 ); // number of cameras
400+ const int N = mMeans .size (0 ); // number of gaussians
401+
402+ TORCH_CHECK (worldToCameraMatrices.sizes () == torch::IntArrayRef ({C, 4 , 4 }),
403+ " worldToCameraMatrices must have shape (C, 4, 4)" );
404+ TORCH_CHECK (projectionMatrices.sizes () == torch::IntArrayRef ({C, 3 , 3 }),
405+ " projectionMatrices must have shape (C, 3, 3)" );
406+ TORCH_CHECK (worldToCameraMatrices.is_contiguous (), " worldToCameraMatrices must be contiguous" );
407+ TORCH_CHECK (projectionMatrices.is_contiguous (), " projectionMatrices must be contiguous" );
408+ TORCH_CHECK (static_cast <int64_t >(pixelsToRender.num_outer_lists ()) == C,
409+ " pixelsToRender must have the same number of outer lists as the number of cameras. "
410+ " Got " ,
411+ pixelsToRender.num_outer_lists (),
412+ " outer lists but " ,
413+ C,
414+ " cameras. " );
415+
416+ SparseProjectedGaussianSplats ret;
417+ ret.mRenderSettings = settings;
418+
419+ // Compute sparse tile info first (this determines which tiles are active)
420+ const int numTilesW = std::ceil (settings.imageWidth / static_cast <float >(settings.tileSize ));
421+ const int numTilesH = std::ceil (settings.imageHeight / static_cast <float >(settings.tileSize ));
422+
423+ const auto [activeTiles, activeTileMask, tilePixelMask, tilePixelCumsum, pixelMap] =
424+ fvdb::detail::ops::computeSparseInfo (
425+ settings.tileSize , numTilesW, numTilesH, pixelsToRender);
426+
427+ ret.activeTiles = activeTiles;
428+ ret.activeTileMask = activeTileMask;
429+ ret.tilePixelMask = tilePixelMask;
430+ ret.tilePixelCumsum = tilePixelCumsum;
431+ ret.pixelMap = pixelMap;
432+
433+ // Track gradients for the 2D means in the backward pass if you're optimizing
434+ std::optional<torch::Tensor> maybeNormalizedMeans2dGradientNorms = std::nullopt ;
435+ std::optional<torch::Tensor> maybePerGaussianRadiiForGrad = std::nullopt ;
436+ std::optional<torch::Tensor> maybeGradientStepCount = std::nullopt ;
437+ if (mAccumulateMean2dGradients ) {
438+ if (mAccumulatedNormalized2dMeansGradientNormsForGrad .numel () != N) {
439+ mAccumulatedNormalized2dMeansGradientNormsForGrad = torch::zeros ({N}, mMeans .options ());
440+ }
441+ if (mGradientStepCountForGrad .numel () != N) {
442+ mGradientStepCountForGrad = torch::zeros (
443+ {N}, torch::TensorOptions ().dtype (torch::kInt32 ).device (mMeans .device ()));
444+ }
445+ maybeNormalizedMeans2dGradientNorms = mAccumulatedNormalized2dMeansGradientNormsForGrad ;
446+ maybeGradientStepCount = mGradientStepCountForGrad ;
447+ }
448+ if (mAccumulateMax2dRadii ) {
449+ if (mAccumulated2dRadiiForGrad .numel () != N && mAccumulateMax2dRadii ) {
450+ mAccumulated2dRadiiForGrad = torch::zeros (
451+ {N}, torch::TensorOptions ().dtype (torch::kInt32 ).device (mMeans .device ()));
452+ }
453+ maybePerGaussianRadiiForGrad = mAccumulated2dRadiiForGrad ;
454+ }
455+
456+ // Project to image plane
457+ const auto projectionResults =
458+ detail::autograd::ProjectGaussians::apply (mMeans ,
459+ mQuats ,
460+ mLogScales ,
461+ worldToCameraMatrices,
462+ projectionMatrices,
463+ settings.imageWidth ,
464+ settings.imageHeight ,
465+ settings.eps2d ,
466+ settings.nearPlane ,
467+ settings.farPlane ,
468+ settings.radiusClip ,
469+ settings.antialias ,
470+ ortho,
471+ maybeNormalizedMeans2dGradientNorms,
472+ maybePerGaussianRadiiForGrad,
473+ maybeGradientStepCount);
474+ ret.perGaussianRadius = projectionResults[0 ];
475+ ret.perGaussian2dMean = projectionResults[1 ];
476+ ret.perGaussianDepth = projectionResults[2 ];
477+ ret.perGaussianConic = projectionResults[3 ];
478+ // FIXME: Use accessors in the kernel and use expand
479+ ret.perGaussianOpacity = opacities ().repeat ({C, 1 });
480+ if (settings.antialias ) {
481+ ret.perGaussianOpacity *= projectionResults[4 ];
482+ // FIXME (Francis): The contiguity requirement is dumb and should be
483+ // removed by using accessors in the kernel
484+ ret.perGaussianOpacity = ret.perGaussianOpacity .contiguous ();
485+ }
486+
487+ ret.perGaussianRenderQuantity = [&]() {
488+ torch::Tensor renderQuantity;
489+ if (settings.renderMode == RenderMode::DEPTH) {
490+ renderQuantity = ret.perGaussianDepth .unsqueeze (-1 ); // [C, N, 1]
491+ } else if (settings.renderMode == RenderMode::RGB ||
492+ settings.renderMode == RenderMode::RGBD) {
493+ renderQuantity = evalSphericalHarmonicsImpl (
494+ settings.shDegreeToUse , worldToCameraMatrices, ret.perGaussianRadius );
495+
496+ if (settings.renderMode == RenderMode::RGBD) {
497+ renderQuantity = torch::cat ({renderQuantity, ret.perGaussianDepth .unsqueeze (-1 )},
498+ -1 ); // [C, N, D + 1]
499+ }
500+ } else {
501+ TORCH_CHECK_VALUE (false , " Invalid render mode" );
502+ }
503+ return renderQuantity;
504+ }();
505+
506+ // Intersect projected Gaussians with image tiles [non-differentiable]
507+ // Use sparse tile intersection which only computes intersections for active tiles
508+ const auto [sparseTileOffsets, tileGaussianIds] = FVDB_DISPATCH_KERNEL (mMeans .device (), [&]() {
509+ return detail::ops::dispatchGaussianSparseTileIntersection<DeviceTag>(ret.perGaussian2dMean ,
510+ ret.perGaussianRadius ,
511+ ret.perGaussianDepth ,
512+ ret.activeTileMask ,
513+ ret.activeTiles ,
514+ at::nullopt ,
515+ C,
516+ settings.tileSize ,
517+ numTilesH,
518+ numTilesW);
519+ });
520+ // Use sparse 1D tile offsets - RasterizeCommonArgs detects the format from dimensions
521+ ret.tileOffsets = sparseTileOffsets; // [num_active_tiles + 1]
522+ ret.tileGaussianIds = tileGaussianIds; // [TOT_INTERSECTIONS]
523+
524+ return ret;
525+ }
526+
392527std::tuple<torch::Tensor, torch::Tensor>
393528GaussianSplat3d::renderCropFromProjectedGaussiansImpl (
394529 const ProjectedGaussianSplats &projectedGaussians,
@@ -454,14 +589,8 @@ GaussianSplat3d::sparseRenderImpl(const JaggedTensor &pixelsToRender,
454589 const fvdb::detail::ops::RenderSettings &settings) {
455590 FVDB_FUNC_RANGE ();
456591
457- const ProjectedGaussianSplats &state =
458- projectGaussiansImpl (worldToCameraMatrices, projectionMatrices, settings);
459-
460- const auto [activeTiles, activeTileMask, tilePixelMask, tilePixelCumsum, pixelMap] =
461- fvdb::detail::ops::computeSparseInfo (settings.tileSize ,
462- state.tileOffsets .size (2 ),
463- state.tileOffsets .size (1 ),
464- pixelsToRender);
592+ const SparseProjectedGaussianSplats &state = sparseProjectGaussiansImpl (
593+ pixelsToRender, worldToCameraMatrices, projectionMatrices, settings);
465594
466595 auto rasterizeResult =
467596 detail::autograd::RasterizeGaussiansToPixelsSparse::apply (pixelsToRender,
@@ -476,10 +605,10 @@ GaussianSplat3d::sparseRenderImpl(const JaggedTensor &pixelsToRender,
476605 settings.tileSize ,
477606 state.tileOffsets ,
478607 state.tileGaussianIds ,
479- activeTiles,
480- tilePixelMask,
481- tilePixelCumsum,
482- pixelMap,
608+ state. activeTiles ,
609+ state. tilePixelMask ,
610+ state. tilePixelCumsum ,
611+ state. pixelMap ,
483612 false );
484613 auto renderedPixelsJData = rasterizeResult[0 ];
485614
@@ -514,14 +643,9 @@ GaussianSplat3d::sparseRenderNumContributingGaussiansImpl(
514643 const torch::Tensor &projectionMatrices,
515644 const fvdb::detail::ops::RenderSettings &settings) {
516645 FVDB_FUNC_RANGE ();
517- const ProjectedGaussianSplats &state =
518- projectGaussiansImpl (worldToCameraMatrices, projectionMatrices, settings);
519646
520- const auto [activeTiles, activeTileMask, tilePixelMask, tilePixelCumsum, pixelMap] =
521- fvdb::detail::ops::computeSparseInfo (settings.tileSize ,
522- state.tileOffsets .size (2 ),
523- state.tileOffsets .size (1 ),
524- pixelsToRender);
647+ const SparseProjectedGaussianSplats &state = sparseProjectGaussiansImpl (
648+ pixelsToRender, worldToCameraMatrices, projectionMatrices, settings);
525649
526650 return FVDB_DISPATCH_KERNEL_DEVICE (state.perGaussian2dMean .device (), [&]() {
527651 return fvdb::detail::ops::dispatchGaussianSparseRasterizeNumContributingGaussians<
@@ -531,10 +655,10 @@ GaussianSplat3d::sparseRenderNumContributingGaussiansImpl(
531655 state.tileOffsets ,
532656 state.tileGaussianIds ,
533657 pixelsToRender,
534- activeTiles,
535- tilePixelMask,
536- tilePixelCumsum,
537- pixelMap,
658+ state. activeTiles ,
659+ state. tilePixelMask ,
660+ state. tilePixelCumsum ,
661+ state. pixelMap ,
538662 settings);
539663 });
540664}
@@ -585,14 +709,9 @@ GaussianSplat3d::sparseRenderContributingGaussianIdsImpl(
585709 const fvdb::detail::ops::RenderSettings &settings,
586710 const std::optional<fvdb::JaggedTensor> &maybeNumContributingGaussians) {
587711 FVDB_FUNC_RANGE ();
588- const ProjectedGaussianSplats &state =
589- projectGaussiansImpl (worldToCameraMatrices, projectionMatrices, settings);
590712
591- const auto [activeTiles, activeTileMask, tilePixelMask, tilePixelCumsum, pixelMap] =
592- fvdb::detail::ops::computeSparseInfo (settings.tileSize ,
593- state.tileOffsets .size (2 ),
594- state.tileOffsets .size (1 ),
595- pixelsToRender);
713+ const SparseProjectedGaussianSplats &state = sparseProjectGaussiansImpl (
714+ pixelsToRender, worldToCameraMatrices, projectionMatrices, settings);
596715
597716 return FVDB_DISPATCH_KERNEL_DEVICE (state.perGaussian2dMean .device (), [&]() {
598717 return fvdb::detail::ops::dispatchGaussianSparseRasterizeContributingGaussianIds<DeviceTag>(
@@ -602,10 +721,10 @@ GaussianSplat3d::sparseRenderContributingGaussianIdsImpl(
602721 state.tileOffsets ,
603722 state.tileGaussianIds ,
604723 pixelsToRender,
605- activeTiles,
606- tilePixelMask,
607- tilePixelCumsum,
608- pixelMap,
724+ state. activeTiles ,
725+ state. tilePixelMask ,
726+ state. tilePixelCumsum ,
727+ state. pixelMap ,
609728 settings,
610729 maybeNumContributingGaussians);
611730 });
0 commit comments