Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
187 changes: 153 additions & 34 deletions src/fvdb/GaussianSplat3d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,141 @@ GaussianSplat3d::projectGaussiansImpl(const torch::Tensor &worldToCameraMatrices
return ret;
}

// Projects the Gaussians into every camera and assembles the sparse tile data
// used to rasterize only the pixels listed in pixelsToRender.
//
// The work happens in this order:
//   1. validate the camera matrices and the per-camera pixel lists,
//   2. compute the sparse tile/pixel info via computeSparseInfo(),
//   3. (optionally) hand the gradient/radius accumulators to the projection,
//   4. project the Gaussians and derive per-Gaussian opacity and render quantity,
//   5. intersect the projected Gaussians with the active tiles only.
//
// Note that the accumulator members may be (re)allocated here when their size
// does not match the current number of Gaussians.
GaussianSplat3d::SparseProjectedGaussianSplats
GaussianSplat3d::sparseProjectGaussiansImpl(const JaggedTensor &pixelsToRender,
                                            const torch::Tensor &worldToCameraMatrices,
                                            const torch::Tensor &projectionMatrices,
                                            const RenderSettings &settings) {
    FVDB_FUNC_RANGE();

    const int numCameras   = worldToCameraMatrices.size(0);
    const int numGaussians = mMeans.size(0);

    // --- Input validation -------------------------------------------------
    TORCH_CHECK(worldToCameraMatrices.sizes() == torch::IntArrayRef({numCameras, 4, 4}),
                "worldToCameraMatrices must have shape (C, 4, 4)");
    TORCH_CHECK(projectionMatrices.sizes() == torch::IntArrayRef({numCameras, 3, 3}),
                "projectionMatrices must have shape (C, 3, 3)");
    TORCH_CHECK(worldToCameraMatrices.is_contiguous(), "worldToCameraMatrices must be contiguous");
    TORCH_CHECK(projectionMatrices.is_contiguous(), "projectionMatrices must be contiguous");
    TORCH_CHECK(static_cast<int64_t>(pixelsToRender.num_outer_lists()) == numCameras,
                "pixelsToRender must have the same number of outer lists as the number of cameras. "
                "Got ",
                pixelsToRender.num_outer_lists(),
                " outer lists but ",
                numCameras,
                " cameras. ");

    SparseProjectedGaussianSplats ret;
    ret.mRenderSettings = settings;

    // --- Sparse tile info (decides which tiles are active) -----------------
    // Tile grid dimensions: image extent divided by the tile size, rounded up.
    const float tileExtent = static_cast<float>(settings.tileSize);
    const int tilesAcross  = std::ceil(settings.imageWidth / tileExtent);
    const int tilesDown    = std::ceil(settings.imageHeight / tileExtent);

    const auto [sparseActiveTiles,
                sparseActiveTileMask,
                sparseTilePixelMask,
                sparseTilePixelCumsum,
                sparsePixelMap] =
        fvdb::detail::ops::computeSparseInfo(
            settings.tileSize, tilesAcross, tilesDown, pixelsToRender);

    ret.activeTiles     = sparseActiveTiles;
    ret.activeTileMask  = sparseActiveTileMask;
    ret.tilePixelMask   = sparseTilePixelMask;
    ret.tilePixelCumsum = sparseTilePixelCumsum;
    ret.pixelMap        = sparsePixelMap;

    // --- Optional accumulators passed to the projection --------------------
    // Each optional stays empty unless the matching accumulation flag is set.
    // A buffer whose element count differs from the number of Gaussians is
    // replaced by a freshly zeroed one.
    std::optional<torch::Tensor> means2dGradNormAccum;
    std::optional<torch::Tensor> gradStepCounter;
    std::optional<torch::Tensor> maxRadiiAccum;
    const auto int32OnMeansDevice =
        torch::TensorOptions().dtype(torch::kInt32).device(mMeans.device());

    if (mAccumulateMean2dGradients) {
        if (mAccumulatedNormalized2dMeansGradientNormsForGrad.numel() != numGaussians) {
            mAccumulatedNormalized2dMeansGradientNormsForGrad =
                torch::zeros({numGaussians}, mMeans.options());
        }
        if (mGradientStepCountForGrad.numel() != numGaussians) {
            mGradientStepCountForGrad = torch::zeros({numGaussians}, int32OnMeansDevice);
        }
        means2dGradNormAccum = mAccumulatedNormalized2dMeansGradientNormsForGrad;
        gradStepCounter      = mGradientStepCountForGrad;
    }
    if (mAccumulateMax2dRadii) {
        if (mAccumulated2dRadiiForGrad.numel() != numGaussians) {
            mAccumulated2dRadiiForGrad = torch::zeros({numGaussians}, int32OnMeansDevice);
        }
        maxRadiiAccum = mAccumulated2dRadiiForGrad;
    }

    // --- Projection onto the image plane -----------------------------------
    const bool isOrthographic =
        settings.projectionType == fvdb::detail::ops::ProjectionType::ORTHOGRAPHIC;
    const auto projected = detail::autograd::ProjectGaussians::apply(mMeans,
                                                                     mQuats,
                                                                     mLogScales,
                                                                     worldToCameraMatrices,
                                                                     projectionMatrices,
                                                                     settings.imageWidth,
                                                                     settings.imageHeight,
                                                                     settings.eps2d,
                                                                     settings.nearPlane,
                                                                     settings.farPlane,
                                                                     settings.radiusClip,
                                                                     settings.antialias,
                                                                     isOrthographic,
                                                                     means2dGradNormAccum,
                                                                     maxRadiiAccum,
                                                                     gradStepCounter);
    ret.perGaussianRadius = projected[0];
    ret.perGaussian2dMean = projected[1];
    ret.perGaussianDepth  = projected[2];
    ret.perGaussianConic  = projected[3];

    // --- Per-camera opacities ----------------------------------------------
    // FIXME: Use accessors in the kernel and use expand instead of repeat.
    auto opacityPerCamera = opacities().repeat({numCameras, 1});
    if (settings.antialias) {
        // With antialiasing on, scale by the fifth projection output.
        opacityPerCamera *= projected[4];
        // FIXME (Francis): The contiguity requirement should be removed by
        // using accessors in the kernel.
        opacityPerCamera = opacityPerCamera.contiguous();
    }
    ret.perGaussianOpacity = opacityPerCamera;

    // --- Quantity to rasterize, chosen by render mode ----------------------
    const bool wantsColor =
        settings.renderMode == RenderMode::RGB || settings.renderMode == RenderMode::RGBD;
    if (wantsColor) {
        torch::Tensor shaded = evalSphericalHarmonicsImpl(
            settings.shDegreeToUse, worldToCameraMatrices, ret.perGaussianRadius);
        if (settings.renderMode == RenderMode::RGBD) {
            // Append depth as an extra trailing channel: [C, N, D + 1]
            shaded = torch::cat({shaded, ret.perGaussianDepth.unsqueeze(-1)}, -1);
        }
        ret.perGaussianRenderQuantity = shaded;
    } else if (settings.renderMode == RenderMode::DEPTH) {
        ret.perGaussianRenderQuantity = ret.perGaussianDepth.unsqueeze(-1); // [C, N, 1]
    } else {
        TORCH_CHECK_VALUE(false, "Invalid render mode");
    }

    // --- Sparse tile intersection [non-differentiable] ---------------------
    // Only the active tiles are intersected with the projected Gaussians.
    const auto [offsetsPerActiveTile, intersectingGaussianIds] =
        FVDB_DISPATCH_KERNEL(mMeans.device(), [&]() {
            return detail::ops::dispatchGaussianSparseTileIntersection<DeviceTag>(
                ret.perGaussian2dMean,
                ret.perGaussianRadius,
                ret.perGaussianDepth,
                ret.activeTileMask,
                ret.activeTiles,
                at::nullopt,
                numCameras,
                settings.tileSize,
                tilesDown,
                tilesAcross);
        });
    // Sparse offsets are 1D; RasterizeCommonArgs detects the format from dimensions.
    ret.tileOffsets     = offsetsPerActiveTile;    // [num_active_tiles + 1]
    ret.tileGaussianIds = intersectingGaussianIds; // [TOT_INTERSECTIONS]

    return ret;
}

std::tuple<torch::Tensor, torch::Tensor>
GaussianSplat3d::renderCropFromProjectedGaussiansImpl(
const ProjectedGaussianSplats &projectedGaussians,
Expand Down Expand Up @@ -454,14 +589,8 @@ GaussianSplat3d::sparseRenderImpl(const JaggedTensor &pixelsToRender,
const fvdb::detail::ops::RenderSettings &settings) {
FVDB_FUNC_RANGE();

const ProjectedGaussianSplats &state =
projectGaussiansImpl(worldToCameraMatrices, projectionMatrices, settings);

const auto [activeTiles, activeTileMask, tilePixelMask, tilePixelCumsum, pixelMap] =
fvdb::detail::ops::computeSparseInfo(settings.tileSize,
state.tileOffsets.size(2),
state.tileOffsets.size(1),
pixelsToRender);
const SparseProjectedGaussianSplats &state = sparseProjectGaussiansImpl(
pixelsToRender, worldToCameraMatrices, projectionMatrices, settings);

auto rasterizeResult =
detail::autograd::RasterizeGaussiansToPixelsSparse::apply(pixelsToRender,
Expand All @@ -476,10 +605,10 @@ GaussianSplat3d::sparseRenderImpl(const JaggedTensor &pixelsToRender,
settings.tileSize,
state.tileOffsets,
state.tileGaussianIds,
activeTiles,
tilePixelMask,
tilePixelCumsum,
pixelMap,
state.activeTiles,
state.tilePixelMask,
state.tilePixelCumsum,
state.pixelMap,
false);
auto renderedPixelsJData = rasterizeResult[0];

Expand Down Expand Up @@ -514,14 +643,9 @@ GaussianSplat3d::sparseRenderNumContributingGaussiansImpl(
const torch::Tensor &projectionMatrices,
const fvdb::detail::ops::RenderSettings &settings) {
FVDB_FUNC_RANGE();
const ProjectedGaussianSplats &state =
projectGaussiansImpl(worldToCameraMatrices, projectionMatrices, settings);

const auto [activeTiles, activeTileMask, tilePixelMask, tilePixelCumsum, pixelMap] =
fvdb::detail::ops::computeSparseInfo(settings.tileSize,
state.tileOffsets.size(2),
state.tileOffsets.size(1),
pixelsToRender);
const SparseProjectedGaussianSplats &state = sparseProjectGaussiansImpl(
pixelsToRender, worldToCameraMatrices, projectionMatrices, settings);

return FVDB_DISPATCH_KERNEL_DEVICE(state.perGaussian2dMean.device(), [&]() {
return fvdb::detail::ops::dispatchGaussianSparseRasterizeNumContributingGaussians<
Expand All @@ -531,10 +655,10 @@ GaussianSplat3d::sparseRenderNumContributingGaussiansImpl(
state.tileOffsets,
state.tileGaussianIds,
pixelsToRender,
activeTiles,
tilePixelMask,
tilePixelCumsum,
pixelMap,
state.activeTiles,
state.tilePixelMask,
state.tilePixelCumsum,
state.pixelMap,
settings);
});
}
Expand Down Expand Up @@ -585,14 +709,9 @@ GaussianSplat3d::sparseRenderContributingGaussianIdsImpl(
const fvdb::detail::ops::RenderSettings &settings,
const std::optional<fvdb::JaggedTensor> &maybeNumContributingGaussians) {
FVDB_FUNC_RANGE();
const ProjectedGaussianSplats &state =
projectGaussiansImpl(worldToCameraMatrices, projectionMatrices, settings);

const auto [activeTiles, activeTileMask, tilePixelMask, tilePixelCumsum, pixelMap] =
fvdb::detail::ops::computeSparseInfo(settings.tileSize,
state.tileOffsets.size(2),
state.tileOffsets.size(1),
pixelsToRender);
const SparseProjectedGaussianSplats &state = sparseProjectGaussiansImpl(
pixelsToRender, worldToCameraMatrices, projectionMatrices, settings);

return FVDB_DISPATCH_KERNEL_DEVICE(state.perGaussian2dMean.device(), [&]() {
return fvdb::detail::ops::dispatchGaussianSparseRasterizeContributingGaussianIds<DeviceTag>(
Expand All @@ -602,10 +721,10 @@ GaussianSplat3d::sparseRenderContributingGaussianIdsImpl(
state.tileOffsets,
state.tileGaussianIds,
pixelsToRender,
activeTiles,
tilePixelMask,
tilePixelCumsum,
pixelMap,
state.activeTiles,
state.tilePixelMask,
state.tilePixelCumsum,
state.pixelMap,
settings,
maybeNumContributingGaussians);
});
Expand Down
73 changes: 23 additions & 50 deletions src/fvdb/GaussianSplat3d.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,17 @@ class GaussianSplat3d {
}
};

/// @brief A set of projected Gaussians with sparse tile intersection data for sparse rendering.
/// This struct extends ProjectedGaussianSplats with additional sparse-specific tensors.
/// The five tensors declared here are copied from the result of
/// fvdb::detail::ops::computeSparseInfo() inside sparseProjectGaussiansImpl(), so the
/// sparse render entry points can read them from the returned state instead of
/// recomputing them.
struct SparseProjectedGaussianSplats : public ProjectedGaussianSplats {
torch::Tensor activeTiles; // [num_active_tiles] - tile IDs of active tiles
torch::Tensor activeTileMask; // [C, TH, TW] - boolean mask of active tiles
torch::Tensor tilePixelMask; // [num_active_tiles, words_per_tile] - bitmask of pixels
torch::Tensor tilePixelCumsum; // [num_active_tiles] - cumulative sum of active pixels
torch::Tensor pixelMap; // [num_active_pixels] - mapping for pixel write order
// Note: tileOffsets (inherited) is 1D [num_active_tiles + 1] in sparse mode
// (it is filled from the sparse tile intersection rather than the dense one).
};

public:
/// @brief Concatenate a vector of GaussianSplat3d objects into a single GaussianSplat3d object.
/// @param splats A vector of GaussianSplat3d objects to concatenate.
Expand Down Expand Up @@ -1282,6 +1293,18 @@ class GaussianSplat3d {
const torch::Tensor &projectionMatrices,
const fvdb::detail::ops::RenderSettings &settings);

/// @brief Project Gaussians with sparse tile intersection for efficient sparse rendering.
/// @param pixelsToRender JaggedTensor of pixel coordinates to render [P1 + P2 + ..., 2].
///        Must have one outer list per camera (checked against C).
/// @param worldToCameraMatrices [C, 4, 4] World-to-camera matrices (must be contiguous)
/// @param projectionMatrices [C, 3, 3] Projection matrices (must be contiguous)
/// @param settings Render settings
/// @return SparseProjectedGaussianSplats containing projected Gaussians and sparse tile data
SparseProjectedGaussianSplats
sparseProjectGaussiansImpl(const JaggedTensor &pixelsToRender,
const torch::Tensor &worldToCameraMatrices,
const torch::Tensor &projectionMatrices,
const fvdb::detail::ops::RenderSettings &settings);

std::tuple<torch::Tensor, torch::Tensor> renderCropFromProjectedGaussiansImpl(
const ProjectedGaussianSplats &state,
const size_t tileSize,
Expand Down Expand Up @@ -1360,56 +1383,6 @@ class GaussianSplat3d {
const torch::Tensor &projectionMatrices,
const fvdb::detail::ops::RenderSettings &settings);

/// @brief Render the gaussian splatting scene
/// For every pixel being rendered, this function returns multiple samples in depth of
/// the gaussian IDs and multiple samples of the weighted alpha values. The number of
/// samples per pixel is determined by the sampling parameters in the settings. If
/// the requested number of samples is greater than the number of
/// contributing samples for a pixel, the remaining samples' weights are filled with
/// zeros and the IDs are filled with -1. The samples are ordered front to back in
/// their depth ordering from camera.
/// @param worldToCameraMatrices [C, 4, 4]
/// @param projectionMatrices [C, 3, 3]
/// @param settings
/// @return Tuple of two tensors:
/// ids: A [B, H, W, K] tensor containing the IDs of the top K contributors to the
/// rendered pixel for each camera
/// weights: A [B, H, W, K] tensor containing the weights of the top K contributors to the
/// rendered pixel for each camera. The weights are normalized to sum to the alpha
/// value of the final rendered pixel if the list is exhaustive of all contributing
/// samples.
std::tuple<torch::Tensor, torch::Tensor>
renderTopContributingGaussianIdsImpl(const torch::Tensor &worldToCameraMatrices,
const torch::Tensor &projectionMatrices,
const fvdb::detail::ops::RenderSettings &settings);

/// @brief Sparse render the gaussian splatting scene
/// For every pixel being rendered, this function returns multiple samples in depth of
/// the gaussian IDs and multiple samples of the weighted alpha values. The number of
/// samples per pixel is determined by the sampling parameters in the settings. If
/// the requested number of samples is greater than the number of
/// contributing samples for a pixel, the remaining samples' weights are filled with
/// zeros and the IDs are filled with -1. The samples are ordered front to back in
/// their depth ordering from camera.
/// @param pixelsToRender [P1 + P2 + ..., 2] JaggedTensor of pixels per camera to render.
/// @param worldToCameraMatrices [C, 4, 4]
/// @param projectionMatrices [C, 3, 3]
/// @param settings
/// @return Tuple of two tensors:
/// ids: A [P1 + P2 + ..., K] jagged tensor containing the IDs of the top K
/// contributors to the rendered pixel for each camera
/// weights: A [P1 + P2 + ..., K] jagged tensor containing the weights of the top K
/// contributors to the rendered pixel for each camera. The weights are normalized
/// to sum to the alpha value of the final rendered pixel if the list is exhaustive
/// of all contributing samples.
std::tuple<fvdb::JaggedTensor, fvdb::JaggedTensor>
sparseRenderTopContributingGaussianIdsImpl(const fvdb::JaggedTensor &pixelsToRender,
const torch::Tensor &worldToCameraMatrices,
const torch::Tensor &projectionMatrices,
const fvdb::detail::ops::RenderSettings &settings);

/// @brief Render the gaussian splatting scene
/// For every pixel being rendered, this function returns multiple samples in depth of
/// the gaussian IDs and multiple samples of the weighted alpha values. The samples are
Expand Down
Loading
Loading