Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions src/fvdb/detail/ops/gsplat/GaussianRasterizeBackward.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1149,11 +1149,7 @@ callRasterizeBackwardWithCorrectSharedChannels(
if (numSharedSharedChannelsOverride > 0) {
return callWithSharedChannels(numSharedSharedChannelsOverride);
} else {
cudaDeviceProp deviceProperties;
if (cudaGetDeviceProperties(&deviceProperties, stream.device_index()) != cudaSuccess) {
AT_ERROR("Failed to query device properties");
}
const size_t maxSharedMemory = deviceProperties.sharedMemPerBlockOptin;
const size_t maxSharedMemory = fvdb::detail::getMaxSharedMemory(stream.device_index());

const size_t sharedMemChannelOptions[4] = {NUM_CHANNELS, 64, 32, 16};
for (size_t i = 0; i < 4; ++i) {
Expand Down
38 changes: 38 additions & 0 deletions src/fvdb/detail/utils/cuda/Utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,44 @@ mergeStreams() {
C10_CUDA_CHECK(cudaEventDestroy(mergeEvent));
}

/// @brief Get the maximum shared memory per block optin for a given device
///
/// This function memoizes the result for each device to minimize calls to
/// cudaDeviceGetAttribute.
///
/// This function is thread-safe.
///
/// @param device The device to get the maximum shared memory per block optin for
/// @return The maximum shared memory per block optin in bytes
inline int
getMaxSharedMemory(int device) {
TORCH_CHECK(device >= 0, "Invalid device index: ", device);

static const int deviceCount = c10::cuda::device_count();
TORCH_CHECK(device < deviceCount, "Device index out of range: ", device);

static std::vector<std::atomic<int>> cache(deviceCount);
static std::vector<std::once_flag> initFlags(deviceCount);

int cached = cache[device].load(std::memory_order_relaxed);
if (cached > 0) {
return cached;
}

std::call_once(initFlags[device], [device]() {
int maxSharedMemory = 0;
TORCH_CHECK(cudaDeviceGetAttribute(&maxSharedMemory,
cudaDevAttrMaxSharedMemoryPerBlockOptin,
device) == cudaSuccess,
"Failed to get max shared memory per block");
cache[device].store(maxSharedMemory, std::memory_order_relaxed);
});

cached = cache[device].load(std::memory_order_relaxed);
TORCH_CHECK(cached > 0, "Failed to initialize max shared memory for device ", device);
return cached;
};

} // namespace detail

} // namespace fvdb
Expand Down
Loading