diff --git a/src/fvdb/detail/ops/gsplat/GaussianRasterizeBackward.cu b/src/fvdb/detail/ops/gsplat/GaussianRasterizeBackward.cu index 3ea5ad9c..2c760391 100644 --- a/src/fvdb/detail/ops/gsplat/GaussianRasterizeBackward.cu +++ b/src/fvdb/detail/ops/gsplat/GaussianRasterizeBackward.cu @@ -701,8 +701,14 @@ struct RasterizeBackwardArgs { const int32_t g = commonArgs.mTileGaussianIds[idx]; // Gaussian index in [C * N] or [nnz] sharedGaussians[tidx] = commonArgs.getGaussian(g); - ScalarType *feature = &sharedGaussianFeatures[tidx * NUM_SHARED_CHANNELS]; - fetchGaussianFeatureIntoSharedMemory(g, channelStart, numChannels, feature); + // Only load features if the Gaussian could be valid for any pixel. + // If opacity < 1/255, alpha < 1/255 for all pixels regardless of + // position, so gaussianIsValid will be false and features are never + // read in the rendering loop. + if (sharedGaussians[tidx].opacity >= 1.f / 255.f) { + ScalarType *feature = &sharedGaussianFeatures[tidx * NUM_SHARED_CHANNELS]; + fetchGaussianFeatureIntoSharedMemory(g, channelStart, numChannels, feature); + } } // Sync threads so all gaussians for this batch are loaded in shared memory diff --git a/src/fvdb/detail/ops/gsplat/GaussianRasterizeForward.cu b/src/fvdb/detail/ops/gsplat/GaussianRasterizeForward.cu index 75dbde1a..b0984275 100644 --- a/src/fvdb/detail/ops/gsplat/GaussianRasterizeForward.cu +++ b/src/fvdb/detail/ops/gsplat/GaussianRasterizeForward.cu @@ -105,6 +105,29 @@ template struct Ras } } + // Fetch the features for a Gaussian into shared memory + inline __device__ void + fetchGaussianFeatureIntoSharedMemory(const int32_t g, ScalarType *outFeatures) { + if constexpr (IS_PACKED) { + const auto featureAccessor = commonArgs.mFeatures[g]; +#pragma unroll NUM_CHANNELS + for (uint32_t k = 0; k < NUM_CHANNELS; ++k) { + outFeatures[k] = featureAccessor[k]; + } + } else { + // colors: [C, N, NUM_CHANNELS] + // colors[c, n, k] = [c * N * NUM_CHANNELS + n * NUM_CHANNELS + k] + // g = c * N + n + 
const int32_t cid = g / commonArgs.mNumGaussiansPerCamera; + const int32_t gid = g % commonArgs.mNumGaussiansPerCamera; + const auto featureAccessor = commonArgs.mFeatures[cid][gid]; +#pragma unroll NUM_CHANNELS + for (auto k = 0; k < NUM_CHANNELS; ++k) { + outFeatures[k] = featureAccessor[k]; + } + } + } + /// @brief Volume render a tile of Gaussians /// @param cameraId The ID of the camera /// @param firstGaussianIdInBlock The first Gaussian ID in the block @@ -124,7 +147,10 @@ template struct Ras const bool pixelIsActive, const uint32_t activePixelIndex) { alignas(Gaussian2D) extern __shared__ char s[]; + auto *sharedGaussians = reinterpret_cast *>(s); // [blockSize] + ScalarType *sharedGaussianFeatures = + reinterpret_cast(&sharedGaussians[blockSize]); // [blockSize * NUM_CHANNELS] // NOTE: The accumulated transmittance is used in the backward pass, and // since it's a @@ -157,7 +183,7 @@ template struct Ras for (uint32_t b = 0; b < numBatches; ++b) { // Sync threads before we start integrating the next batch // If all threads are done, we can break early - if (__syncthreads_count(done) == blockSize) { + if (__syncthreads_and(done)) { break; } @@ -169,6 +195,14 @@ template struct Ras const int32_t g = commonArgs.mTileGaussianIds[idx]; // which gaussian we're rendering sharedGaussians[tidx] = commonArgs.getGaussian(g); + // Only load features if the Gaussian could be valid for any pixel. + // If opacity < 1/255, alpha < 1/255 for all pixels regardless of + // position, so gaussianIsValid will be false and features are never + // read in the rendering loop. 
+ if (sharedGaussians[tidx].opacity >= 1.f / 255.f) { + ScalarType *feature = &sharedGaussianFeatures[tidx * NUM_CHANNELS]; + fetchGaussianFeatureIntoSharedMemory(g, feature); + } } // Sync threads so all gaussians for this batch are loaded in shared @@ -194,19 +228,10 @@ template struct Ras break; } - const ScalarType vis = alpha * accumTransmittance; - const auto featureAccessor = [&]() { - if constexpr (IS_PACKED) { - return commonArgs.mFeatures[gaussian.id]; - } else { - const int32_t cid = gaussian.id / commonArgs.mNumGaussiansPerCamera; - const int32_t gid = gaussian.id % commonArgs.mNumGaussiansPerCamera; - return commonArgs.mFeatures[cid][gid]; - } - }(); + const ScalarType vis = alpha * accumTransmittance; PRAGMA_UNROLL for (uint32_t k = 0; k < NUM_CHANNELS; ++k) { - pixOut[k] += featureAccessor[k] * vis; + pixOut[k] += sharedGaussianFeatures[t * NUM_CHANNELS + k] * vis; } curIdx = batchStart + t; @@ -277,6 +302,9 @@ rasterizeGaussiansForward(RasterizeForwardArgs -size_t -getSharedMemRequirements(const size_t tileSize) { - return tileSize * tileSize * sizeof(Gaussian2D); +constexpr size_t +getSharedMemRequirements(const size_t numChannels, const size_t tileSize) { + return tileSize * tileSize * + (sizeof(Gaussian2D) + numChannels * sizeof(ScalarType)); } template @@ -382,7 +411,7 @@ launchRasterizeForwardKernel( const at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream(); // Thread blocks cooperatively cache a tile of Gaussians in shared memory - const uint32_t sharedMem = getSharedMemRequirements(tileSize); + const uint32_t sharedMem = getSharedMemRequirements(NUM_CHANNELS, tileSize); // TODO: an optimization can be done by passing the actual number of // channels into the kernel functions and avoid necessary global memory @@ -495,9 +524,10 @@ launchRasterizeForwardKernels( uint32_t deviceTileOffset, deviceTileCount; std::tie(deviceTileOffset, deviceTileCount) = deviceChunk(tileCount, deviceId); - uint32_t cameraOffset = deviceTileOffset / 
(tileCount / C); + uint32_t cameraOffset = deviceTileOffset / (tileExtentH * tileExtentW); uint32_t cameraCount = - cuda::ceil_div(deviceTileOffset + deviceTileCount, tileCount / C) - cameraOffset; + cuda::ceil_div(deviceTileOffset + deviceTileCount, (tileExtentH * tileExtentW)) - + cameraOffset; if (deviceTileCount) { auto reshapedAlphas = outAlphas.jdata().view({C, renderWindow.height, renderWindow.width, 1}); @@ -548,7 +578,7 @@ launchRasterizeForwardKernels( pixelMap); // Thread blocks cooperatively cache a tile of Gaussians in shared memory - const uint32_t sharedMem = getSharedMemRequirements(tileSize); + const uint32_t sharedMem = getSharedMemRequirements(NUM_CHANNELS, tileSize); // TODO: an optimization can be done by passing the actual number of // channels into the kernel functions and avoid necessary global memory