Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions src/fvdb/detail/ops/gsplat/GaussianRasterizeBackward.cu
Original file line number Diff line number Diff line change
Expand Up @@ -701,8 +701,14 @@ struct RasterizeBackwardArgs {
const int32_t g =
commonArgs.mTileGaussianIds[idx]; // Gaussian index in [C * N] or [nnz]
sharedGaussians[tidx] = commonArgs.getGaussian(g);
ScalarType *feature = &sharedGaussianFeatures[tidx * NUM_SHARED_CHANNELS];
fetchGaussianFeatureIntoSharedMemory(g, channelStart, numChannels, feature);
// Only load features if the Gaussian could be valid for any pixel.
// If opacity < 1/255, alpha < 1/255 for all pixels regardless of
// position, so gaussianIsValid will be false and features are never
// read in the rendering loop.
if (sharedGaussians[tidx].opacity >= 1.f / 255.f) {
ScalarType *feature = &sharedGaussianFeatures[tidx * NUM_SHARED_CHANNELS];
fetchGaussianFeatureIntoSharedMemory(g, channelStart, numChannels, feature);
}
}

// Sync threads so all gaussians for this batch are loaded in shared memory
Expand Down
68 changes: 49 additions & 19 deletions src/fvdb/detail/ops/gsplat/GaussianRasterizeForward.cu
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,29 @@ template <typename ScalarType, uint32_t NUM_CHANNELS, bool IS_PACKED> struct Ras
}
}

// Fetch the features for a Gaussian into shared memory
//
// Copies the NUM_CHANNELS feature values of Gaussian `g` from global memory
// into the caller-provided slot (intended to be a per-thread region of the
// block's dynamic shared memory, sized NUM_CHANNELS per Gaussian).
//
// @param g Flattened Gaussian index. When IS_PACKED, `g` indexes
//          mFeatures directly (the [nnz] layout). Otherwise
//          g = cameraId * mNumGaussiansPerCamera + gaussianId, and is
//          decomposed below to index the [C, N, NUM_CHANNELS] layout.
// @param outFeatures Destination with room for NUM_CHANNELS ScalarType
//          values; presumably points into shared memory — the caller owns
//          the placement, nothing here enforces it.
inline __device__ void
fetchGaussianFeatureIntoSharedMemory(const int32_t g, ScalarType *outFeatures) {
    if constexpr (IS_PACKED) {
        // Packed layout: `g` is already a direct row index into mFeatures.
        const auto featureAccessor = commonArgs.mFeatures[g];
        // NUM_CHANNELS is a compile-time template parameter, so the copy
        // loop can be fully unrolled.
#pragma unroll NUM_CHANNELS
        for (uint32_t k = 0; k < NUM_CHANNELS; ++k) {
            outFeatures[k] = featureAccessor[k];
        }
    } else {
        // colors: [C, N, NUM_CHANNELS]
        // colors[c, n, k] = [c * N * NUM_CHANNELS + n * NUM_CHANNELS + k]
        // g = c * N + n
        // Split the flat index into camera id (cid) and per-camera
        // Gaussian id (gid) before indexing the accessor.
        const int32_t cid = g / commonArgs.mNumGaussiansPerCamera;
        const int32_t gid = g % commonArgs.mNumGaussiansPerCamera;
        const auto featureAccessor = commonArgs.mFeatures[cid][gid];
        // NOTE(review): `auto k` deduces signed int here while the sibling
        // branch uses uint32_t — harmless for the comparison against the
        // unsigned template constant, but inconsistent.
#pragma unroll NUM_CHANNELS
        for (auto k = 0; k < NUM_CHANNELS; ++k) {
            outFeatures[k] = featureAccessor[k];
        }
    }
}

/// @brief Volume render a tile of Gaussians
/// @param cameraId The ID of the camera
/// @param firstGaussianIdInBlock The first Gaussian ID in the block
Expand All @@ -124,7 +147,10 @@ template <typename ScalarType, uint32_t NUM_CHANNELS, bool IS_PACKED> struct Ras
const bool pixelIsActive,
const uint32_t activePixelIndex) {
alignas(Gaussian2D<ScalarType>) extern __shared__ char s[];

auto *sharedGaussians = reinterpret_cast<Gaussian2D<ScalarType> *>(s); // [blockSize]
ScalarType *sharedGaussianFeatures =
reinterpret_cast<ScalarType *>(&sharedGaussians[blockSize]); // [blockSize]

// NOTE: The accumulated transmittance is used in the backward pass, and
// since it's a
Expand Down Expand Up @@ -157,7 +183,7 @@ template <typename ScalarType, uint32_t NUM_CHANNELS, bool IS_PACKED> struct Ras
for (uint32_t b = 0; b < numBatches; ++b) {
// Sync threads before we start integrating the next batch
// If all threads are done, we can break early
if (__syncthreads_count(done) == blockSize) {
if (__syncthreads_and(done)) {
break;
}

Expand All @@ -169,6 +195,14 @@ template <typename ScalarType, uint32_t NUM_CHANNELS, bool IS_PACKED> struct Ras
const int32_t g =
commonArgs.mTileGaussianIds[idx]; // which gaussian we're rendering
sharedGaussians[tidx] = commonArgs.getGaussian(g);
// Only load features if the Gaussian could be valid for any pixel.
// If opacity < 1/255, alpha < 1/255 for all pixels regardless of
// position, so gaussianIsValid will be false and features are never
// read in the rendering loop.
if (sharedGaussians[tidx].opacity >= 1.f / 255.f) {
ScalarType *feature = &sharedGaussianFeatures[tidx * NUM_CHANNELS];
fetchGaussianFeatureIntoSharedMemory(g, feature);
}
}

// Sync threads so all gaussians for this batch are loaded in shared
Expand All @@ -194,19 +228,10 @@ template <typename ScalarType, uint32_t NUM_CHANNELS, bool IS_PACKED> struct Ras
break;
}

const ScalarType vis = alpha * accumTransmittance;
const auto featureAccessor = [&]() {
if constexpr (IS_PACKED) {
return commonArgs.mFeatures[gaussian.id];
} else {
const int32_t cid = gaussian.id / commonArgs.mNumGaussiansPerCamera;
const int32_t gid = gaussian.id % commonArgs.mNumGaussiansPerCamera;
return commonArgs.mFeatures[cid][gid];
}
}();
const ScalarType vis = alpha * accumTransmittance;
PRAGMA_UNROLL
for (uint32_t k = 0; k < NUM_CHANNELS; ++k) {
pixOut[k] += featureAccessor[k] * vis;
pixOut[k] += sharedGaussianFeatures[t * NUM_CHANNELS + k] * vis;
}

curIdx = batchStart + t;
Expand Down Expand Up @@ -277,6 +302,9 @@ rasterizeGaussiansForward(RasterizeForwardArgs<ScalarType, NUM_CHANNELS, IS_PACK
cuda::std::tie(firstGaussianIdInBlock, lastGaussianIdInBlock) =
commonArgs.tileGaussianRange(cameraId, tileRow, tileCol);

// if (row == 0 && col == 0)
// printf("(%d, %d, %d) = [%d, %d)\n", (int)cameraId, (int)row, (int)col,
// (int)firstGaussianIdInBlock, (int)lastGaussianIdInBlock);
args.volumeRenderTileForward(cameraId,
row,
col,
Expand All @@ -291,9 +319,10 @@ rasterizeGaussiansForward(RasterizeForwardArgs<ScalarType, NUM_CHANNELS, IS_PACK
/// @param tileSize The size of the tile
/// @return The shared memory required in bytes
template <typename ScalarType>
size_t
getSharedMemRequirements(const size_t tileSize) {
return tileSize * tileSize * sizeof(Gaussian2D<ScalarType>);
constexpr size_t
getSharedMemRequirements(const size_t numChannels, const size_t tileSize) {
return tileSize * tileSize *
(sizeof(Gaussian2D<ScalarType>) + numChannels * sizeof(ScalarType));
}

template <typename ScalarType, uint32_t NUM_CHANNELS, bool IS_PACKED>
Expand Down Expand Up @@ -382,7 +411,7 @@ launchRasterizeForwardKernel(
const at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();

// Thread blocks cooperatively cache a tile of Gaussians in shared memory
const uint32_t sharedMem = getSharedMemRequirements<ScalarType>(tileSize);
const uint32_t sharedMem = getSharedMemRequirements<ScalarType>(NUM_CHANNELS, tileSize);

// TODO: an optimization can be done by passing the actual number of
// channels into the kernel functions and avoid unnecessary global memory
Expand Down Expand Up @@ -495,9 +524,10 @@ launchRasterizeForwardKernels(

uint32_t deviceTileOffset, deviceTileCount;
std::tie(deviceTileOffset, deviceTileCount) = deviceChunk(tileCount, deviceId);
uint32_t cameraOffset = deviceTileOffset / (tileCount / C);
uint32_t cameraOffset = deviceTileOffset / (tileExtentH * tileExtentW);
uint32_t cameraCount =
cuda::ceil_div(deviceTileOffset + deviceTileCount, tileCount / C) - cameraOffset;
cuda::ceil_div(deviceTileOffset + deviceTileCount, (tileExtentH * tileExtentW)) -
cameraOffset;
if (deviceTileCount) {
auto reshapedAlphas =
outAlphas.jdata().view({C, renderWindow.height, renderWindow.width, 1});
Expand Down Expand Up @@ -548,7 +578,7 @@ launchRasterizeForwardKernels(
pixelMap);

// Thread blocks cooperatively cache a tile of Gaussians in shared memory
const uint32_t sharedMem = getSharedMemRequirements<ScalarType>(tileSize);
const uint32_t sharedMem = getSharedMemRequirements<ScalarType>(NUM_CHANNELS, tileSize);
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🚩 issue: Shouldn't this be NUM_SHARED_CHANNELS? Also what happens if the number of channels is too large to fit all the features in shared memory?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's no NUM_SHARED_CHANNELS in the forwards pass (it's just NUM_CHANNELS) because there isn't chunking implemented. So assuming unlimited shared memory, this would be correct as written.

That being said, the lack of chunking is likely why the tests are failing for large feature depths so I'll have to add that.


// TODO: an optimization can be done by passing the actual number of
// channels into the kernel functions and avoid unnecessary global memory
Expand Down
Loading