Skip to content

Commit 6721e4f

Browse files
refactor(gsplat): introduce shared from-world rasterize args
Add a `RasterizeFromWorldCommonArgs` helper, used by the forward and backward from-world kernels, to centralize dense tile coordinates, tile range lookup, and optional background/mask access while preserving existing rendering and gradient behavior.

Signed-off-by: Francis Williams <francis@fwilliams.info>
Co-authored-by: Cursor <cursoragent@cursor.com>
1 parent 6ea83c2 commit 6721e4f

File tree

3 files changed

+156
-141
lines changed

3 files changed

+156
-141
lines changed

src/fvdb/detail/ops/gsplat/GaussianRasterizeFromWorld.cuh

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,104 @@
1010
#include <fvdb/detail/ops/gsplat/GaussianRigidTransform.cuh>
1111
#include <fvdb/detail/ops/gsplat/GaussianRollingShutter.cuh>
1212
#include <fvdb/detail/ops/gsplat/GaussianUtils.cuh>
13+
#include <fvdb/detail/utils/AccessorHelpers.cuh>
1314

1415
#include <nanovdb/math/Math.h>
1516
#include <nanovdb/math/Ray.h>
1617

18+
#include <cuda/std/tuple>
19+
1720
#include <cstdint>
1821

1922
namespace fvdb::detail::ops {
2023

2124
/// Opacity threshold used by 3DGS alpha compositing.
2225
constexpr __device__ float kAlphaThreshold = 0.999f;
2326

27+
/// Common dense-tile rasterization arguments shared by from-world forward/backward kernels.
///
/// Bundles the image/tile geometry, the per-tile intersection ranges and the optional
/// background/mask pointers, together with the index arithmetic the kernels perform on them.
///
/// Launch layout expected by `denseCoordinates`: a 1-D grid with one block per (camera, tile),
/// linearized camera-major and then row-major over the tiles, and a 2-D block in which
/// `threadIdx.x` / `threadIdx.y` select the column / row of a pixel inside the tile.
///
/// NOTE: callers brace-initialize this struct positionally, so the order of the data members
/// below must not change.
struct RasterizeFromWorldCommonArgs {
    using TileOffsetsAccessor     = fvdb::TorchRAcc64<int32_t, 3>;
    using TileGaussianIdsAccessor = fvdb::TorchRAcc64<int32_t, 1>;

    uint32_t numCameras;  // C
    uint32_t imageWidth;  // pixels with col >= imageWidth are outside the image
    uint32_t imageHeight; // pixels with row >= imageHeight are outside the image
    // NOTE(review): the origin is only forwarded to the ray-generation helper in the visible
    // code; presumably it is the offset of this image window within a larger image -- confirm.
    uint32_t imageOriginW;
    uint32_t imageOriginH;
    uint32_t tileSize;           // edge length of a square tile, in pixels
    uint32_t numTilesW;          // TW: number of tiles along the image width
    uint32_t numTilesH;          // TH: number of tiles along the image height
    uint32_t numChannels;        // D: feature channels per pixel
    int32_t totalIntersections;  // end offset of the very last tile's range (n_isects)

    TileOffsetsAccessor tileOffsets;         // [C, TH, TW]
    TileGaussianIdsAccessor tileGaussianIds; // [n_isects]
    const float *backgrounds;                // [C, D] or nullptr
    const bool *masks;                       // [C, TH, TW] or nullptr

    /// Decodes `blockIdx.x` into (camera, tile row, tile column) and combines the tile with
    /// `threadIdx` to obtain the pixel (row, col) handled by the calling thread.
    ///
    /// The returned pixel may lie outside the image on edge tiles; callers must compare it
    /// against `imageHeight` / `imageWidth` themselves.
    inline __device__ void
    denseCoordinates(uint32_t &cameraId,
                     uint32_t &tileRow,
                     uint32_t &tileCol,
                     uint32_t &row,
                     uint32_t &col) const {
        const uint32_t tilesPerCamera = numTilesH * numTilesW;
        const uint32_t linearBlock    = blockIdx.x;

        cameraId = linearBlock / tilesPerCamera;

        const uint32_t tileLinear = linearBlock - cameraId * tilesPerCamera;
        tileRow                   = tileLinear / numTilesW;
        tileCol                   = tileLinear - tileRow * numTilesW;

        row = tileRow * tileSize + threadIdx.y;
        col = tileCol * tileSize + threadIdx.x;
    }

    /// Flat index of a tile in a contiguous [C, TH, TW] array.
    inline __device__ uint32_t
    tileId(const uint32_t cameraId, const uint32_t tileRow, const uint32_t tileCol) const {
        return cameraId * numTilesH * numTilesW + tileRow * numTilesW + tileCol;
    }

    /// Returns true when a mask is present and its entry for this tile is false.
    /// Without a mask (`masks == nullptr`) no tile is masked.
    inline __device__ bool
    tileMasked(const uint32_t cameraId, const uint32_t tileRow, const uint32_t tileCol) const {
        if (masks == nullptr) {
            return false;
        }
        return !masks[tileId(cameraId, tileRow, tileCol)];
    }

    /// Returns the half-open range [start, end) into `tileGaussianIds` for the given tile.
    ///
    /// `tileOffsets` only stores each tile's start offset, so the end is the start of the next
    /// tile in (camera, row, column) order, or `totalIntersections` for the very last tile.
    inline __device__ cuda::std::tuple<int32_t, int32_t>
    tileGaussianRange(const uint32_t cameraId,
                      const uint32_t tileRow,
                      const uint32_t tileCol) const {
        const int32_t rangeStart = tileOffsets[cameraId][tileRow][tileCol];

        int32_t rangeEnd;
        if (tileCol + 1 < numTilesW) {
            // Next tile in the same row.
            rangeEnd = tileOffsets[cameraId][tileRow][tileCol + 1];
        } else if (tileRow + 1 < numTilesH) {
            // First tile of the next row of the same camera.
            rangeEnd = tileOffsets[cameraId][tileRow + 1][0];
        } else if (cameraId + 1 < numCameras) {
            // First tile of the next camera.
            rangeEnd = tileOffsets[cameraId + 1][0][0];
        } else {
            // Last tile overall: there is no successor entry to read.
            rangeEnd = totalIntersections;
        }
        return {rangeStart, rangeEnd};
    }

    /// Row-major index of a pixel within a single image.
    inline __device__ uint32_t
    pixelId(const uint32_t row, const uint32_t col) const {
        return row * imageWidth + col;
    }

    /// Flat index of a pixel in a contiguous [C, H, W] array.
    ///
    /// Computed and returned in 64 bits: C * H * W can exceed 2^32 for large batches, and a
    /// 32-bit product would silently wrap around.
    inline __device__ uint64_t
    outputPixelBase(const uint32_t cameraId, const uint32_t pixId) const {
        return static_cast<uint64_t>(cameraId) * imageHeight * imageWidth + pixId;
    }

    /// Flat index of a pixel's first channel in a contiguous [C, H, W, D] array.
    ///
    /// Computed and returned in 64 bits for the same reason as `outputPixelBase`; the extra
    /// factor D makes 32-bit wrap-around reachable even sooner.
    inline __device__ uint64_t
    outputFeatureBase(const uint32_t cameraId, const uint32_t pixId) const {
        return outputPixelBase(cameraId, pixId) * numChannels;
    }

    /// Background value for (camera, channel), or 0 when no background was supplied.
    inline __device__ float
    backgroundValue(const uint32_t cameraId, const uint32_t channelId) const {
        if (backgrounds == nullptr) {
            return 0.0f;
        }
        return backgrounds[cameraId * numChannels + channelId];
    }
};
110+
24111
template <typename T>
25112
inline __device__ nanovdb::math::Vec3<T>
26113
normalizeSafe(const nanovdb::math::Vec3<T> &v) {

src/fvdb/detail/ops/gsplat/GaussianRasterizeFromWorldBackward.cu

Lines changed: 33 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,6 @@ namespace cg = cooperative_groups;
2020

2121
namespace {
2222

23-
// TODO(fvdb): Consider refactoring this kernel to reuse a common rasterization args struct
24-
// (`RasterizeCommonArgs` or a derived struct) for consistency with the other rasterizers.
2523
template <uint32_t NUM_CHANNELS> struct SharedGaussian {
2624
int32_t id; // flattened id in [0, C*N)
2725
nanovdb::math::Vec3<float> mean; // world mean
@@ -34,6 +32,7 @@ template <uint32_t NUM_CHANNELS> struct SharedGaussian {
3432
template <uint32_t NUM_CHANNELS>
3533
__global__ void
3634
rasterizeFromWorld3DGSBackwardKernel(
35+
const RasterizeFromWorldCommonArgs commonArgs,
3736
// Gaussians
3837
const torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> means, // [N,3]
3938
const torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> quats, // [N,4]
@@ -51,20 +50,6 @@ rasterizeFromWorld3DGSBackwardKernel(
5150
const int64_t numDistCoeffs,
5251
const RollingShutterType rollingShutterType,
5352
const CameraModel cameraModel,
54-
// Settings
55-
const uint32_t imageWidth,
56-
const uint32_t imageHeight,
57-
const uint32_t imageOriginW,
58-
const uint32_t imageOriginH,
59-
const uint32_t tileSize,
60-
const uint32_t tileExtentW,
61-
const uint32_t tileExtentH,
62-
// Intersections
63-
const torch::PackedTensorAccessor64<int32_t, 3, torch::RestrictPtrTraits>
64-
tileOffsets, // [C, tileExtentH, tileExtentW]
65-
const torch::PackedTensorAccessor64<int32_t, 1, torch::RestrictPtrTraits>
66-
tileGaussianIds, // [n_isects]
67-
const int32_t totalIntersections,
6853
// Forward outputs
6954
const torch::PackedTensorAccessor64<float, 4, torch::RestrictPtrTraits>
7055
renderedAlphas, // [C,H,W,1]
@@ -74,10 +59,6 @@ rasterizeFromWorld3DGSBackwardKernel(
7459
dLossDRenderedFeatures, // [C,H,W,D]
7560
const torch::PackedTensorAccessor64<float, 4, torch::RestrictPtrTraits>
7661
dLossDRenderedAlphas, // [C,H,W,1]
77-
// Backgrounds
78-
const float *__restrict__ backgrounds, // [C,D] or nullptr
79-
// Optional tile masks
80-
const bool *__restrict__ masks, // [C, tileExtentH, tileExtentW] or nullptr
8162
// Outputs (grads)
8263
float *__restrict__ dMeans, // [N,3]
8364
float *__restrict__ dQuats, // [N,4]
@@ -88,24 +69,16 @@ rasterizeFromWorld3DGSBackwardKernel(
8869
auto block = cg::this_thread_block();
8970
const uint32_t blockSize = blockDim.x * blockDim.y;
9071

91-
const uint32_t globalLinearBlock = blockIdx.x;
92-
const uint32_t camId = globalLinearBlock / (tileExtentH * tileExtentW);
93-
const uint32_t tileLinear = globalLinearBlock - camId * (tileExtentH * tileExtentW);
94-
const uint32_t tileRow = tileLinear / tileExtentW;
95-
const uint32_t tileCol = tileLinear - tileRow * tileExtentW;
96-
97-
const uint32_t row = tileRow * tileSize + threadIdx.y;
98-
const uint32_t col = tileCol * tileSize + threadIdx.x;
99-
const bool inside = (row < imageHeight && col < imageWidth);
72+
uint32_t camId, tileRow, tileCol, row, col;
73+
commonArgs.denseCoordinates(camId, tileRow, tileCol, row, col);
74+
const bool inside = (row < commonArgs.imageHeight && col < commonArgs.imageWidth);
10075

10176
// Parity with classic rasterizer: masked tiles contribute nothing.
10277
//
10378
// IMPORTANT: this kernel uses block-level barriers later (`block.sync`). Any early return must
10479
// be taken by *all* threads in the block, otherwise edge tiles can deadlock when some threads
10580
// are `!inside`. So we make the return block-wide.
106-
const bool tileMasked =
107-
(masks != nullptr) &&
108-
(!masks[camId * tileExtentH * tileExtentW + tileRow * tileExtentW + tileCol]);
81+
const bool tileMasked = commonArgs.tileMasked(camId, tileRow, tileCol);
10982
if (tileMasked) {
11083
return;
11184
}
@@ -123,10 +96,10 @@ rasterizeFromWorld3DGSBackwardKernel(
12396

12497
const nanovdb::math::Ray<float> ray = pixelToWorldRay<float>(row,
12598
col,
126-
imageWidth,
127-
imageHeight,
128-
imageOriginW,
129-
imageOriginH,
99+
commonArgs.imageWidth,
100+
commonArgs.imageHeight,
101+
commonArgs.imageOriginW,
102+
commonArgs.imageOriginH,
130103
R_wc_start,
131104
t_wc_start,
132105
R_wc_end,
@@ -144,20 +117,7 @@ rasterizeFromWorld3DGSBackwardKernel(
144117
const bool done = inside && rayValid;
145118

146119
// Gaussian range for this tile.
147-
const int32_t rangeStart = tileOffsets[camId][tileRow][tileCol];
148-
int32_t rangeEnd = 0;
149-
if ((camId == (uint32_t)(features.size(0) - 1)) && (tileRow == tileExtentH - 1) &&
150-
(tileCol == tileExtentW - 1)) {
151-
rangeEnd = totalIntersections;
152-
} else if (tileCol + 1 < tileExtentW) {
153-
rangeEnd = tileOffsets[camId][tileRow][tileCol + 1];
154-
} else {
155-
if (tileRow + 1 < tileExtentH) {
156-
rangeEnd = tileOffsets[camId][tileRow + 1][0];
157-
} else {
158-
rangeEnd = tileOffsets[camId + 1][0][0];
159-
}
160-
}
120+
const auto [rangeStart, rangeEnd] = commonArgs.tileGaussianRange(camId, tileRow, tileCol);
161121

162122
// If the tile has no intersections, there is nothing to do. This must be a block-wide return.
163123
if (rangeEnd <= rangeStart) {
@@ -218,7 +178,7 @@ rasterizeFromWorld3DGSBackwardKernel(
218178
const int32_t idx = batchEnd - (int32_t)threadRank;
219179

220180
if (idx >= rangeStart) {
221-
const int32_t flatId = tileGaussianIds[idx];
181+
const int32_t flatId = commonArgs.tileGaussianIds[idx];
222182
idBatch[threadRank] = flatId;
223183
const int32_t gid = flatId % (int32_t)means.size(0);
224184
const int32_t cid = flatId / (int32_t)means.size(0);
@@ -322,11 +282,11 @@ rasterizeFromWorld3DGSBackwardKernel(
322282

323283
v_alpha += T_final * ra * v_render_a;
324284

325-
if (backgrounds != nullptr) {
285+
if (commonArgs.backgrounds != nullptr) {
326286
float accum = 0.f;
327287
#pragma unroll
328288
for (uint32_t k = 0; k < NUM_CHANNELS; ++k) {
329-
accum += backgrounds[camId * NUM_CHANNELS + k] * v_render_c[k];
289+
accum += commonArgs.backgroundValue(camId, k) * v_render_c[k];
330290
}
331291
v_alpha += -T_final * ra * accum;
332292
}
@@ -469,17 +429,36 @@ launchBackward(const torch::Tensor &means,
469429
const uint32_t tileExtentH = (imageHeight + tileSize - 1) / tileSize;
470430
const dim3 blockDim(tileSize, tileSize, 1);
471431
const dim3 gridDim(C * tileExtentH * tileExtentW, 1, 1);
472-
const int32_t totalIntersections = (int32_t)tileGaussianIds.size(0);
432+
const int32_t totalIntersections = static_cast<int32_t>(tileGaussianIds.size(0));
473433
const int64_t numDistCoeffs = distortionCoeffs.size(1);
474434

435+
RasterizeFromWorldCommonArgs commonArgs{
436+
static_cast<uint32_t>(C),
437+
imageWidth,
438+
imageHeight,
439+
imageOriginW,
440+
imageOriginH,
441+
tileSize,
442+
tileExtentW,
443+
tileExtentH,
444+
NUM_CHANNELS,
445+
totalIntersections,
446+
tileOffsets.packed_accessor64<int32_t, 3, torch::RestrictPtrTraits>(),
447+
tileGaussianIds.packed_accessor64<int32_t, 1, torch::RestrictPtrTraits>(),
448+
nullptr,
449+
nullptr};
450+
475451
const PreparedRasterOptionalInputs opt = prepareRasterOptionalInputs(
476452
features, C, tileExtentH, tileExtentW, (int64_t)NUM_CHANNELS, backgrounds, masks);
453+
commonArgs.backgrounds = opt.backgrounds;
454+
commonArgs.masks = opt.masks;
477455

478456
const size_t blockSize = (size_t)tileSize * (size_t)tileSize;
479457
const size_t sharedMem = blockSize * (sizeof(int32_t) + sizeof(SharedGaussian<NUM_CHANNELS>));
480458

481459
auto stream = at::cuda::getDefaultCUDAStream();
482460
rasterizeFromWorld3DGSBackwardKernel<NUM_CHANNELS><<<gridDim, blockDim, sharedMem, stream>>>(
461+
commonArgs,
483462
means.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
484463
quats.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
485464
logScales.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
@@ -492,22 +471,10 @@ launchBackward(const torch::Tensor &means,
492471
numDistCoeffs,
493472
rollingShutterType,
494473
cameraModel,
495-
imageWidth,
496-
imageHeight,
497-
imageOriginW,
498-
imageOriginH,
499-
tileSize,
500-
tileExtentW,
501-
tileExtentH,
502-
tileOffsets.packed_accessor64<int32_t, 3, torch::RestrictPtrTraits>(),
503-
tileGaussianIds.packed_accessor64<int32_t, 1, torch::RestrictPtrTraits>(),
504-
totalIntersections,
505474
renderedAlphas.packed_accessor64<float, 4, torch::RestrictPtrTraits>(),
506475
lastIds.packed_accessor64<int32_t, 3, torch::RestrictPtrTraits>(),
507476
dLossDRenderedFeatures.packed_accessor64<float, 4, torch::RestrictPtrTraits>(),
508477
dLossDRenderedAlphas.packed_accessor64<float, 4, torch::RestrictPtrTraits>(),
509-
opt.backgrounds,
510-
opt.masks,
511478
dMeans.data_ptr<float>(),
512479
dQuats.data_ptr<float>(),
513480
dLogScales.data_ptr<float>(),

0 commit comments

Comments
 (0)