@@ -20,8 +20,6 @@ namespace cg = cooperative_groups;
2020
2121namespace {
2222
23- // TODO(fvdb): Consider refactoring this kernel to reuse a common rasterization args struct
24- // (`RasterizeCommonArgs` or a derived struct) for consistency with the other rasterizers.
2523template <uint32_t NUM_CHANNELS> struct SharedGaussian {
2624 int32_t id; // flattened id in [0, C*N)
2725 nanovdb::math::Vec3<float > mean; // world mean
@@ -34,6 +32,7 @@ template <uint32_t NUM_CHANNELS> struct SharedGaussian {
3432template <uint32_t NUM_CHANNELS>
3533__global__ void
3634rasterizeFromWorld3DGSBackwardKernel (
35+ const RasterizeFromWorldCommonArgs commonArgs,
3736 // Gaussians
3837 const torch::PackedTensorAccessor64<float , 2 , torch::RestrictPtrTraits> means, // [N,3]
3938 const torch::PackedTensorAccessor64<float , 2 , torch::RestrictPtrTraits> quats, // [N,4]
@@ -51,20 +50,6 @@ rasterizeFromWorld3DGSBackwardKernel(
5150 const int64_t numDistCoeffs,
5251 const RollingShutterType rollingShutterType,
5352 const CameraModel cameraModel,
54- // Settings
55- const uint32_t imageWidth,
56- const uint32_t imageHeight,
57- const uint32_t imageOriginW,
58- const uint32_t imageOriginH,
59- const uint32_t tileSize,
60- const uint32_t tileExtentW,
61- const uint32_t tileExtentH,
62- // Intersections
63- const torch::PackedTensorAccessor64<int32_t , 3 , torch::RestrictPtrTraits>
64- tileOffsets, // [C, tileExtentH, tileExtentW]
65- const torch::PackedTensorAccessor64<int32_t , 1 , torch::RestrictPtrTraits>
66- tileGaussianIds, // [n_isects]
67- const int32_t totalIntersections,
6853 // Forward outputs
6954 const torch::PackedTensorAccessor64<float , 4 , torch::RestrictPtrTraits>
7055 renderedAlphas, // [C,H,W,1]
@@ -74,10 +59,6 @@ rasterizeFromWorld3DGSBackwardKernel(
7459 dLossDRenderedFeatures, // [C,H,W,D]
7560 const torch::PackedTensorAccessor64<float , 4 , torch::RestrictPtrTraits>
7661 dLossDRenderedAlphas, // [C,H,W,1]
77- // Backgrounds
78- const float *__restrict__ backgrounds, // [C,D] or nullptr
79- // Optional tile masks
80- const bool *__restrict__ masks, // [C, tileExtentH, tileExtentW] or nullptr
8162 // Outputs (grads)
8263 float *__restrict__ dMeans, // [N,3]
8364 float *__restrict__ dQuats, // [N,4]
@@ -88,24 +69,16 @@ rasterizeFromWorld3DGSBackwardKernel(
8869 auto block = cg::this_thread_block ();
8970 const uint32_t blockSize = blockDim .x * blockDim .y ;
9071
91- const uint32_t globalLinearBlock = blockIdx .x ;
92- const uint32_t camId = globalLinearBlock / (tileExtentH * tileExtentW);
93- const uint32_t tileLinear = globalLinearBlock - camId * (tileExtentH * tileExtentW);
94- const uint32_t tileRow = tileLinear / tileExtentW;
95- const uint32_t tileCol = tileLinear - tileRow * tileExtentW;
96-
97- const uint32_t row = tileRow * tileSize + threadIdx .y ;
98- const uint32_t col = tileCol * tileSize + threadIdx .x ;
99- const bool inside = (row < imageHeight && col < imageWidth);
72+ uint32_t camId, tileRow, tileCol, row, col;
73+ commonArgs.denseCoordinates (camId, tileRow, tileCol, row, col);
74+ const bool inside = (row < commonArgs.imageHeight && col < commonArgs.imageWidth );
10075
10176 // Parity with classic rasterizer: masked tiles contribute nothing.
10277 //
10378 // IMPORTANT: this kernel uses block-level barriers later (`block.sync`). Any early return must
10479 // be taken by *all* threads in the block, otherwise edge tiles can deadlock when some threads
10580 // are `!inside`. So we make the return block-wide.
106- const bool tileMasked =
107- (masks != nullptr ) &&
108- (!masks[camId * tileExtentH * tileExtentW + tileRow * tileExtentW + tileCol]);
81+ const bool tileMasked = commonArgs.tileMasked (camId, tileRow, tileCol);
10982 if (tileMasked) {
11083 return ;
11184 }
@@ -123,10 +96,10 @@ rasterizeFromWorld3DGSBackwardKernel(
12396
12497 const nanovdb::math::Ray<float > ray = pixelToWorldRay<float >(row,
12598 col,
126- imageWidth,
127- imageHeight,
128- imageOriginW,
129- imageOriginH,
99+ commonArgs. imageWidth ,
100+ commonArgs. imageHeight ,
101+ commonArgs. imageOriginW ,
102+ commonArgs. imageOriginH ,
130103 R_wc_start,
131104 t_wc_start,
132105 R_wc_end,
@@ -144,20 +117,7 @@ rasterizeFromWorld3DGSBackwardKernel(
144117 const bool done = inside && rayValid;
145118
146119 // Gaussian range for this tile.
147- const int32_t rangeStart = tileOffsets[camId][tileRow][tileCol];
148- int32_t rangeEnd = 0 ;
149- if ((camId == (uint32_t )(features.size (0 ) - 1 )) && (tileRow == tileExtentH - 1 ) &&
150- (tileCol == tileExtentW - 1 )) {
151- rangeEnd = totalIntersections;
152- } else if (tileCol + 1 < tileExtentW) {
153- rangeEnd = tileOffsets[camId][tileRow][tileCol + 1 ];
154- } else {
155- if (tileRow + 1 < tileExtentH) {
156- rangeEnd = tileOffsets[camId][tileRow + 1 ][0 ];
157- } else {
158- rangeEnd = tileOffsets[camId + 1 ][0 ][0 ];
159- }
160- }
120+ const auto [rangeStart, rangeEnd] = commonArgs.tileGaussianRange (camId, tileRow, tileCol);
161121
162122 // If the tile has no intersections, there is nothing to do. This must be a block-wide return.
163123 if (rangeEnd <= rangeStart) {
@@ -218,7 +178,7 @@ rasterizeFromWorld3DGSBackwardKernel(
218178 const int32_t idx = batchEnd - (int32_t )threadRank;
219179
220180 if (idx >= rangeStart) {
221- const int32_t flatId = tileGaussianIds[idx];
181+ const int32_t flatId = commonArgs. tileGaussianIds [idx];
222182 idBatch[threadRank] = flatId;
223183 const int32_t gid = flatId % (int32_t )means.size (0 );
224184 const int32_t cid = flatId / (int32_t )means.size (0 );
@@ -322,11 +282,11 @@ rasterizeFromWorld3DGSBackwardKernel(
322282
323283 v_alpha += T_final * ra * v_render_a;
324284
325- if (backgrounds != nullptr ) {
285+ if (commonArgs. backgrounds != nullptr ) {
326286 float accum = 0 .f ;
327287#pragma unroll
328288 for (uint32_t k = 0 ; k < NUM_CHANNELS; ++k) {
329- accum += backgrounds[ camId * NUM_CHANNELS + k] * v_render_c[k];
289+ accum += commonArgs. backgroundValue ( camId, k) * v_render_c[k];
330290 }
331291 v_alpha += -T_final * ra * accum;
332292 }
@@ -469,17 +429,36 @@ launchBackward(const torch::Tensor &means,
469429 const uint32_t tileExtentH = (imageHeight + tileSize - 1 ) / tileSize;
470430 const dim3 blockDim (tileSize, tileSize, 1 );
471431 const dim3 gridDim (C * tileExtentH * tileExtentW, 1 , 1 );
472- const int32_t totalIntersections = ( int32_t ) tileGaussianIds.size (0 );
432+ const int32_t totalIntersections = static_cast < int32_t >( tileGaussianIds.size (0 ) );
473433 const int64_t numDistCoeffs = distortionCoeffs.size (1 );
474434
435+ RasterizeFromWorldCommonArgs commonArgs{
436+ static_cast <uint32_t >(C),
437+ imageWidth,
438+ imageHeight,
439+ imageOriginW,
440+ imageOriginH,
441+ tileSize,
442+ tileExtentW,
443+ tileExtentH,
444+ NUM_CHANNELS,
445+ totalIntersections,
446+ tileOffsets.packed_accessor64 <int32_t , 3 , torch::RestrictPtrTraits>(),
447+ tileGaussianIds.packed_accessor64 <int32_t , 1 , torch::RestrictPtrTraits>(),
448+ nullptr ,
449+ nullptr };
450+
475451 const PreparedRasterOptionalInputs opt = prepareRasterOptionalInputs (
476452 features, C, tileExtentH, tileExtentW, (int64_t )NUM_CHANNELS, backgrounds, masks);
453+ commonArgs.backgrounds = opt.backgrounds ;
454+ commonArgs.masks = opt.masks ;
477455
478456 const size_t blockSize = (size_t )tileSize * (size_t )tileSize;
479457 const size_t sharedMem = blockSize * (sizeof (int32_t ) + sizeof (SharedGaussian<NUM_CHANNELS>));
480458
481459 auto stream = at::cuda::getDefaultCUDAStream ();
482460 rasterizeFromWorld3DGSBackwardKernel<NUM_CHANNELS><<<gridDim , blockDim , sharedMem, stream>>> (
461+ commonArgs,
483462 means.packed_accessor64 <float , 2 , torch::RestrictPtrTraits>(),
484463 quats.packed_accessor64 <float , 2 , torch::RestrictPtrTraits>(),
485464 logScales.packed_accessor64 <float , 2 , torch::RestrictPtrTraits>(),
@@ -492,22 +471,10 @@ launchBackward(const torch::Tensor &means,
492471 numDistCoeffs,
493472 rollingShutterType,
494473 cameraModel,
495- imageWidth,
496- imageHeight,
497- imageOriginW,
498- imageOriginH,
499- tileSize,
500- tileExtentW,
501- tileExtentH,
502- tileOffsets.packed_accessor64 <int32_t , 3 , torch::RestrictPtrTraits>(),
503- tileGaussianIds.packed_accessor64 <int32_t , 1 , torch::RestrictPtrTraits>(),
504- totalIntersections,
505474 renderedAlphas.packed_accessor64 <float , 4 , torch::RestrictPtrTraits>(),
506475 lastIds.packed_accessor64 <int32_t , 3 , torch::RestrictPtrTraits>(),
507476 dLossDRenderedFeatures.packed_accessor64 <float , 4 , torch::RestrictPtrTraits>(),
508477 dLossDRenderedAlphas.packed_accessor64 <float , 4 , torch::RestrictPtrTraits>(),
509- opt.backgrounds ,
510- opt.masks ,
511478 dMeans.data_ptr <float >(),
512479 dQuats.data_ptr <float >(),
513480 dLogScales.data_ptr <float >(),
0 commit comments