Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions src/fvdb/detail/ops/gsplat/GaussianRasterizeBackward.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1149,11 +1149,7 @@ callRasterizeBackwardWithCorrectSharedChannels(
if (numSharedSharedChannelsOverride > 0) {
return callWithSharedChannels(numSharedSharedChannelsOverride);
} else {
cudaDeviceProp deviceProperties;
if (cudaGetDeviceProperties(&deviceProperties, stream.device_index()) != cudaSuccess) {
AT_ERROR("Failed to query device properties");
}
const size_t maxSharedMemory = deviceProperties.sharedMemPerBlockOptin;
const size_t maxSharedMemory = fvdb::detail::getMaxSharedMemory(stream.device_index());

const size_t sharedMemChannelOptions[4] = {NUM_CHANNELS, 64, 32, 16};
for (size_t i = 0; i < 4; ++i) {
Expand Down
37 changes: 37 additions & 0 deletions src/fvdb/detail/utils/cuda/Utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,43 @@ mergeStreams() {
C10_CUDA_CHECK(cudaEventDestroy(mergeEvent));
}

/// @brief Get the maximum shared memory per block optin for a given device
///
/// This function memoizes the result for each device to minimize calls to
/// cudaDeviceGetAttribute.
///
/// This function is thread-safe.
///
/// @param device The device to get the maximum shared memory per block optin for
/// @return The maximum shared memory per block optin in bytes
inline int
getMaxSharedMemory(int device) {
TORCH_CHECK(device >= 0, "Invalid device index: ", device);

static const int deviceCount = c10::cuda::device_count();
TORCH_CHECK(device < deviceCount, "Device index out of range: ", device);

static std::vector<std::atomic<int>> cache(deviceCount);
static std::vector<std::once_flag> initFlags(deviceCount);

if (int cached = cache[device].load(std::memory_order_relaxed); cached > 0) {
return cached;
}

std::call_once(initFlags[device], [device]() {
int maxSharedMemory = 0;
TORCH_CHECK(cudaDeviceGetAttribute(&maxSharedMemory,
cudaDevAttrMaxSharedMemoryPerBlockOptin,
device) == cudaSuccess,
"Failed to get max shared memory per block optin");
cache[device].store(maxSharedMemory, std::memory_order_relaxed);
});

int cached = cache[device].load(std::memory_order_relaxed);
TORCH_CHECK(cached > 0, "Failed to initialize max shared memory for device ", device);
return cached;
};

} // namespace detail

} // namespace fvdb
Expand Down
Loading