Skip to content

Commit 402c74a

Browse files
authored
Plumb Sparse Gaussian TileIntersection (#401)
Plumbs `dispatchGaussianSparseTileIntersection` into the appropriate sparse rendering functions. --------- Signed-off-by: Jonathan Swartz <[email protected]>
1 parent 6c09178 commit 402c74a

19 files changed

+417
-243
lines changed

src/fvdb/GaussianSplat3d.cpp

Lines changed: 153 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -389,6 +389,141 @@ GaussianSplat3d::projectGaussiansImpl(const torch::Tensor &worldToCameraMatrices
389389
return ret;
390390
}
391391

392+
GaussianSplat3d::SparseProjectedGaussianSplats
393+
GaussianSplat3d::sparseProjectGaussiansImpl(const JaggedTensor &pixelsToRender,
394+
const torch::Tensor &worldToCameraMatrices,
395+
const torch::Tensor &projectionMatrices,
396+
const RenderSettings &settings) {
397+
FVDB_FUNC_RANGE();
398+
const bool ortho = settings.projectionType == fvdb::detail::ops::ProjectionType::ORTHOGRAPHIC;
399+
const int C = worldToCameraMatrices.size(0); // number of cameras
400+
const int N = mMeans.size(0); // number of gaussians
401+
402+
TORCH_CHECK(worldToCameraMatrices.sizes() == torch::IntArrayRef({C, 4, 4}),
403+
"worldToCameraMatrices must have shape (C, 4, 4)");
404+
TORCH_CHECK(projectionMatrices.sizes() == torch::IntArrayRef({C, 3, 3}),
405+
"projectionMatrices must have shape (C, 3, 3)");
406+
TORCH_CHECK(worldToCameraMatrices.is_contiguous(), "worldToCameraMatrices must be contiguous");
407+
TORCH_CHECK(projectionMatrices.is_contiguous(), "projectionMatrices must be contiguous");
408+
TORCH_CHECK(static_cast<int64_t>(pixelsToRender.num_outer_lists()) == C,
409+
"pixelsToRender must have the same number of outer lists as the number of cameras. "
410+
"Got ",
411+
pixelsToRender.num_outer_lists(),
412+
" outer lists but ",
413+
C,
414+
" cameras. ");
415+
416+
SparseProjectedGaussianSplats ret;
417+
ret.mRenderSettings = settings;
418+
419+
// Compute sparse tile info first (this determines which tiles are active)
420+
const int numTilesW = std::ceil(settings.imageWidth / static_cast<float>(settings.tileSize));
421+
const int numTilesH = std::ceil(settings.imageHeight / static_cast<float>(settings.tileSize));
422+
423+
const auto [activeTiles, activeTileMask, tilePixelMask, tilePixelCumsum, pixelMap] =
424+
fvdb::detail::ops::computeSparseInfo(
425+
settings.tileSize, numTilesW, numTilesH, pixelsToRender);
426+
427+
ret.activeTiles = activeTiles;
428+
ret.activeTileMask = activeTileMask;
429+
ret.tilePixelMask = tilePixelMask;
430+
ret.tilePixelCumsum = tilePixelCumsum;
431+
ret.pixelMap = pixelMap;
432+
433+
// Track gradients for the 2D means in the backward pass if you're optimizing
434+
std::optional<torch::Tensor> maybeNormalizedMeans2dGradientNorms = std::nullopt;
435+
std::optional<torch::Tensor> maybePerGaussianRadiiForGrad = std::nullopt;
436+
std::optional<torch::Tensor> maybeGradientStepCount = std::nullopt;
437+
if (mAccumulateMean2dGradients) {
438+
if (mAccumulatedNormalized2dMeansGradientNormsForGrad.numel() != N) {
439+
mAccumulatedNormalized2dMeansGradientNormsForGrad = torch::zeros({N}, mMeans.options());
440+
}
441+
if (mGradientStepCountForGrad.numel() != N) {
442+
mGradientStepCountForGrad = torch::zeros(
443+
{N}, torch::TensorOptions().dtype(torch::kInt32).device(mMeans.device()));
444+
}
445+
maybeNormalizedMeans2dGradientNorms = mAccumulatedNormalized2dMeansGradientNormsForGrad;
446+
maybeGradientStepCount = mGradientStepCountForGrad;
447+
}
448+
if (mAccumulateMax2dRadii) {
449+
if (mAccumulated2dRadiiForGrad.numel() != N && mAccumulateMax2dRadii) {
450+
mAccumulated2dRadiiForGrad = torch::zeros(
451+
{N}, torch::TensorOptions().dtype(torch::kInt32).device(mMeans.device()));
452+
}
453+
maybePerGaussianRadiiForGrad = mAccumulated2dRadiiForGrad;
454+
}
455+
456+
// Project to image plane
457+
const auto projectionResults =
458+
detail::autograd::ProjectGaussians::apply(mMeans,
459+
mQuats,
460+
mLogScales,
461+
worldToCameraMatrices,
462+
projectionMatrices,
463+
settings.imageWidth,
464+
settings.imageHeight,
465+
settings.eps2d,
466+
settings.nearPlane,
467+
settings.farPlane,
468+
settings.radiusClip,
469+
settings.antialias,
470+
ortho,
471+
maybeNormalizedMeans2dGradientNorms,
472+
maybePerGaussianRadiiForGrad,
473+
maybeGradientStepCount);
474+
ret.perGaussianRadius = projectionResults[0];
475+
ret.perGaussian2dMean = projectionResults[1];
476+
ret.perGaussianDepth = projectionResults[2];
477+
ret.perGaussianConic = projectionResults[3];
478+
// FIXME: Use accessors in the kernel so the opacities can be broadcast with expand() instead of copied with repeat()
479+
ret.perGaussianOpacity = opacities().repeat({C, 1});
480+
if (settings.antialias) {
481+
ret.perGaussianOpacity *= projectionResults[4];
482+
// FIXME (Francis): The contiguity requirement is dumb and should be
483+
// removed by using accessors in the kernel
484+
ret.perGaussianOpacity = ret.perGaussianOpacity.contiguous();
485+
}
486+
487+
ret.perGaussianRenderQuantity = [&]() {
488+
torch::Tensor renderQuantity;
489+
if (settings.renderMode == RenderMode::DEPTH) {
490+
renderQuantity = ret.perGaussianDepth.unsqueeze(-1); // [C, N, 1]
491+
} else if (settings.renderMode == RenderMode::RGB ||
492+
settings.renderMode == RenderMode::RGBD) {
493+
renderQuantity = evalSphericalHarmonicsImpl(
494+
settings.shDegreeToUse, worldToCameraMatrices, ret.perGaussianRadius);
495+
496+
if (settings.renderMode == RenderMode::RGBD) {
497+
renderQuantity = torch::cat({renderQuantity, ret.perGaussianDepth.unsqueeze(-1)},
498+
-1); // [C, N, D + 1]
499+
}
500+
} else {
501+
TORCH_CHECK_VALUE(false, "Invalid render mode");
502+
}
503+
return renderQuantity;
504+
}();
505+
506+
// Intersect projected Gaussians with image tiles [non-differentiable]
507+
// Use sparse tile intersection which only computes intersections for active tiles
508+
const auto [sparseTileOffsets, tileGaussianIds] = FVDB_DISPATCH_KERNEL(mMeans.device(), [&]() {
509+
return detail::ops::dispatchGaussianSparseTileIntersection<DeviceTag>(ret.perGaussian2dMean,
510+
ret.perGaussianRadius,
511+
ret.perGaussianDepth,
512+
ret.activeTileMask,
513+
ret.activeTiles,
514+
at::nullopt,
515+
C,
516+
settings.tileSize,
517+
numTilesH,
518+
numTilesW);
519+
});
520+
// Use sparse 1D tile offsets - RasterizeCommonArgs detects the format from dimensions
521+
ret.tileOffsets = sparseTileOffsets; // [num_active_tiles + 1]
522+
ret.tileGaussianIds = tileGaussianIds; // [TOT_INTERSECTIONS]
523+
524+
return ret;
525+
}
526+
392527
std::tuple<torch::Tensor, torch::Tensor>
393528
GaussianSplat3d::renderCropFromProjectedGaussiansImpl(
394529
const ProjectedGaussianSplats &projectedGaussians,
@@ -454,14 +589,8 @@ GaussianSplat3d::sparseRenderImpl(const JaggedTensor &pixelsToRender,
454589
const fvdb::detail::ops::RenderSettings &settings) {
455590
FVDB_FUNC_RANGE();
456591

457-
const ProjectedGaussianSplats &state =
458-
projectGaussiansImpl(worldToCameraMatrices, projectionMatrices, settings);
459-
460-
const auto [activeTiles, activeTileMask, tilePixelMask, tilePixelCumsum, pixelMap] =
461-
fvdb::detail::ops::computeSparseInfo(settings.tileSize,
462-
state.tileOffsets.size(2),
463-
state.tileOffsets.size(1),
464-
pixelsToRender);
592+
const SparseProjectedGaussianSplats &state = sparseProjectGaussiansImpl(
593+
pixelsToRender, worldToCameraMatrices, projectionMatrices, settings);
465594

466595
auto rasterizeResult =
467596
detail::autograd::RasterizeGaussiansToPixelsSparse::apply(pixelsToRender,
@@ -476,10 +605,10 @@ GaussianSplat3d::sparseRenderImpl(const JaggedTensor &pixelsToRender,
476605
settings.tileSize,
477606
state.tileOffsets,
478607
state.tileGaussianIds,
479-
activeTiles,
480-
tilePixelMask,
481-
tilePixelCumsum,
482-
pixelMap,
608+
state.activeTiles,
609+
state.tilePixelMask,
610+
state.tilePixelCumsum,
611+
state.pixelMap,
483612
false);
484613
auto renderedPixelsJData = rasterizeResult[0];
485614

@@ -514,14 +643,9 @@ GaussianSplat3d::sparseRenderNumContributingGaussiansImpl(
514643
const torch::Tensor &projectionMatrices,
515644
const fvdb::detail::ops::RenderSettings &settings) {
516645
FVDB_FUNC_RANGE();
517-
const ProjectedGaussianSplats &state =
518-
projectGaussiansImpl(worldToCameraMatrices, projectionMatrices, settings);
519646

520-
const auto [activeTiles, activeTileMask, tilePixelMask, tilePixelCumsum, pixelMap] =
521-
fvdb::detail::ops::computeSparseInfo(settings.tileSize,
522-
state.tileOffsets.size(2),
523-
state.tileOffsets.size(1),
524-
pixelsToRender);
647+
const SparseProjectedGaussianSplats &state = sparseProjectGaussiansImpl(
648+
pixelsToRender, worldToCameraMatrices, projectionMatrices, settings);
525649

526650
return FVDB_DISPATCH_KERNEL_DEVICE(state.perGaussian2dMean.device(), [&]() {
527651
return fvdb::detail::ops::dispatchGaussianSparseRasterizeNumContributingGaussians<
@@ -531,10 +655,10 @@ GaussianSplat3d::sparseRenderNumContributingGaussiansImpl(
531655
state.tileOffsets,
532656
state.tileGaussianIds,
533657
pixelsToRender,
534-
activeTiles,
535-
tilePixelMask,
536-
tilePixelCumsum,
537-
pixelMap,
658+
state.activeTiles,
659+
state.tilePixelMask,
660+
state.tilePixelCumsum,
661+
state.pixelMap,
538662
settings);
539663
});
540664
}
@@ -585,14 +709,9 @@ GaussianSplat3d::sparseRenderContributingGaussianIdsImpl(
585709
const fvdb::detail::ops::RenderSettings &settings,
586710
const std::optional<fvdb::JaggedTensor> &maybeNumContributingGaussians) {
587711
FVDB_FUNC_RANGE();
588-
const ProjectedGaussianSplats &state =
589-
projectGaussiansImpl(worldToCameraMatrices, projectionMatrices, settings);
590712

591-
const auto [activeTiles, activeTileMask, tilePixelMask, tilePixelCumsum, pixelMap] =
592-
fvdb::detail::ops::computeSparseInfo(settings.tileSize,
593-
state.tileOffsets.size(2),
594-
state.tileOffsets.size(1),
595-
pixelsToRender);
713+
const SparseProjectedGaussianSplats &state = sparseProjectGaussiansImpl(
714+
pixelsToRender, worldToCameraMatrices, projectionMatrices, settings);
596715

597716
return FVDB_DISPATCH_KERNEL_DEVICE(state.perGaussian2dMean.device(), [&]() {
598717
return fvdb::detail::ops::dispatchGaussianSparseRasterizeContributingGaussianIds<DeviceTag>(
@@ -602,10 +721,10 @@ GaussianSplat3d::sparseRenderContributingGaussianIdsImpl(
602721
state.tileOffsets,
603722
state.tileGaussianIds,
604723
pixelsToRender,
605-
activeTiles,
606-
tilePixelMask,
607-
tilePixelCumsum,
608-
pixelMap,
724+
state.activeTiles,
725+
state.tilePixelMask,
726+
state.tilePixelCumsum,
727+
state.pixelMap,
609728
settings,
610729
maybeNumContributingGaussians);
611730
});

src/fvdb/GaussianSplat3d.h

Lines changed: 23 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,17 @@ class GaussianSplat3d {
150150
}
151151
};
152152

153+
/// @brief A set of projected Gaussians with sparse tile intersection data for sparse rendering.
154+
/// This struct extends ProjectedGaussianSplats with additional sparse-specific tensors.
155+
struct SparseProjectedGaussianSplats : public ProjectedGaussianSplats {
156+
torch::Tensor activeTiles; // [num_active_tiles] - tile IDs of active tiles
157+
torch::Tensor activeTileMask; // [C, TH, TW] - boolean mask of active tiles
158+
torch::Tensor tilePixelMask; // [num_active_tiles, words_per_tile] - bitmask of pixels
159+
torch::Tensor tilePixelCumsum; // [num_active_tiles] - cumulative sum of active pixels
160+
torch::Tensor pixelMap; // [num_active_pixels] - mapping for pixel write order
161+
// Note: tileOffsets (inherited) is 1D [num_active_tiles + 1] in sparse mode
162+
};
163+
153164
public:
154165
/// @brief Concatenate a vector of GaussianSplat3d objects into a single GaussianSplat3d object.
155166
/// @param splats A vector of GaussianSplat3d objects to concatenate.
@@ -1284,6 +1295,18 @@ class GaussianSplat3d {
12841295
const torch::Tensor &projectionMatrices,
12851296
const fvdb::detail::ops::RenderSettings &settings);
12861297

1298+
/// @brief Project Gaussians with sparse tile intersection for efficient sparse rendering.
1299+
/// @param pixelsToRender JaggedTensor of pixel coordinates to render [P1 + P2 + ..., 2]
1300+
/// @param worldToCameraMatrices [C, 4, 4] World-to-camera matrices
1301+
/// @param projectionMatrices [C, 3, 3] Projection matrices
1302+
/// @param settings Render settings
1303+
/// @return SparseProjectedGaussianSplats containing projected Gaussians and sparse tile data
1304+
SparseProjectedGaussianSplats
1305+
sparseProjectGaussiansImpl(const JaggedTensor &pixelsToRender,
1306+
const torch::Tensor &worldToCameraMatrices,
1307+
const torch::Tensor &projectionMatrices,
1308+
const fvdb::detail::ops::RenderSettings &settings);
1309+
12871310
std::tuple<torch::Tensor, torch::Tensor> renderCropFromProjectedGaussiansImpl(
12881311
const ProjectedGaussianSplats &state,
12891312
const size_t tileSize,
@@ -1362,56 +1385,6 @@ class GaussianSplat3d {
13621385
const torch::Tensor &projectionMatrices,
13631386
const fvdb::detail::ops::RenderSettings &settings);
13641387

1365-
/// @brief Render the gaussian splatting scene
1366-
/// For every pixel being rendered, this function returns multiple samples in depth of
1367-
/// the gaussian IDs and multiple samples of the weighted alpha values. The number of
1368-
/// samples per pixel is determined by the sampling parameters in the settings. If
1369-
/// the size of the requested number of samples is greater than the number of
1370-
/// contributing samples for a pixel, the remaining samples' weights are filled with
1371-
/// zeros and the IDs are filled with -1. The samples are ordered front to back in
1372-
/// their depth ordering from camera.
1373-
/// @param worldToCameraMatrices [C, 4, 4]
1374-
/// @param projectionMatrices [C, 3, 3]
1375-
/// @param settings
1376-
/// @return Tuple of two tensors:
1377-
/// ids: A [B, H, W, K] tensor containing the IDs of the top K contributors to the
1378-
/// rendered pixel for each camera
1379-
/// weights: A [B, H, W, K] tensor containing the weights of the top K contributors to the
1380-
/// rendered pixel for each camera. The weights are normalized to sum to the alpha
1381-
/// value of the final rendered pixel if the list is exhaustive of all contributing
1382-
/// samples.
1383-
std::tuple<torch::Tensor, torch::Tensor>
1384-
renderTopContributingGaussianIdsImpl(const torch::Tensor &worldToCameraMatrices,
1385-
const torch::Tensor &projectionMatrices,
1386-
const fvdb::detail::ops::RenderSettings &settings);
1387-
1388-
/// @brief Sparse render the gaussian splatting scene
1389-
/// For every pixel being rendered, this function returns multiple samples in depth of
1390-
/// the gaussian IDs and multiple samples of the weighted alpha values. The number of
1391-
/// samples per pixel is determined by the sampling parameters in the settings. If
1392-
/// the size of the requested number of samples is greater than the number of
1393-
/// contributing samples for a pixel, the remaining samples' weights are filled with
1394-
/// zeros and the IDs are filled with -1. The samples are ordered front to back in
1395-
/// their depth ordering from camera.
1396-
/// @param pixelsToRender [P1 + P2 + ..., 2] JaggedTensor of pixels per camera to render.
1397-
/// @param worldToCameraMatrices [C, 4, 4]
1398-
/// @param projectionMatrices [C, 3, 3]
1399-
/// @param settings
1400-
/// @return Tuple of two tensors:
1401-
/// ids: A [P1 + P2 + ..., K] jagged tensor containing the IDs of the top K contributors
1402-
/// to the
1403-
/// rendered pixel for each camera
1404-
/// weights: A [P1 + P2 + ..., K] jagged tensor containing the weights of the top K
1405-
/// contributors to the
1406-
/// rendered pixel for each camera. The weights are normalized to sum to the alpha
1407-
/// value of the final rendered pixel if the list is exhaustive of all contributing
1408-
/// samples.
1409-
std::tuple<fvdb::JaggedTensor, fvdb::JaggedTensor>
1410-
sparseRenderTopContributingGaussianIdsImpl(const fvdb::JaggedTensor &pixelsToRender,
1411-
const torch::Tensor &worldToCameraMatrices,
1412-
const torch::Tensor &projectionMatrices,
1413-
const fvdb::detail::ops::RenderSettings &settings);
1414-
14151388
/// @brief Render the gaussian splatting scene
14161389
/// For every pixel being rendered, this function returns multiple samples in depth of
14171390
/// the gaussian IDs and multiple samples of the weighted alpha values. The samples are

0 commit comments

Comments
 (0)