fbgemm_gpu/codegen/utils/embedding_bounds_check_v2.cu (35 changes: 32 additions & 3 deletions)

@@ -28,9 +28,15 @@ __global__ __launch_bounds__(kMaxThreads) void bounds_check_indices_kernel_v2(

   const index_t num_indices = indices.size(0);
   const auto b_t_start = blockIdx.x * blockDim.y + threadIdx.y;
-  index_t invalid_i = -1, invalid_idx = -1;
-  int32_t invalid_b_t = -1;
-  int64_t warning_inc = 0;
+#ifdef USE_ROCM
+  index_t invalid_i = -1, invalid_idx = -1;
+  int32_t invalid_b_t = -1;
+  int64_t warning_inc = 0;
+  __shared__ int64_t block_warning_buffer[kMaxThreads];
+  const int linear_tid = threadIdx.z * (blockDim.y * blockDim.x) +
+      threadIdx.y * blockDim.x + threadIdx.x;
+  const int active_threads = blockDim.x * blockDim.y * blockDim.z;
+#endif
 
   // Check the last element
   if (b_t_start == 0 && threadIdx.x == 0) {
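
[Note: linear_tid is the row-major flattening of the 3-D thread index. For example, with blockDim = (32, 8, 1), thread (x=5, y=3, z=0) gets linear_tid = 0*(8*32) + 3*32 + 5 = 101, and active_threads = 32*8*1 = 256.]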
@@ -142,9 +148,32 @@ __global__ __launch_bounds__(kMaxThreads) void bounds_check_indices_kernel_v2(
     }
   } // for b_t
 
+#ifdef USE_ROCM
+  // Accumulate per-thread warning counts in shared memory and reduce once per block.

  [Inline review thread on this hunk]

  q10 (Contributor), Nov 20, 2025:
    Could you guard the changes with #ifdef USE_ROCM for now, at least until
    we can ascertain that this does not adversely impact NVIDIA performance?

  Author (Contributor):
    Guarded in 2aaaee3

+  block_warning_buffer[linear_tid] = warning_inc;
+  __syncthreads();
+
+  // Parallel tree reduction
+  for (int stride = active_threads / 2; stride > 0; stride >>= 1) {
+    if (linear_tid < stride) {
+      block_warning_buffer[linear_tid] += block_warning_buffer[linear_tid + stride];
+    }
+    __syncthreads();
+  }
+
+  // Thread 0 has the final sum
+  if (linear_tid == 0) {
+    int64_t block_warning_sum = block_warning_buffer[0];
+    if (block_warning_sum > 0) {
+      gpuAtomicAdd(&warning[0], block_warning_sum);
+    }
+  }
+  __syncthreads();
+#else
   if (warning_inc > 0) {
     gpuAtomicAdd(&warning[0], warning_inc);
   }
+#endif
   if (bounds_check_mode == BoundsCheckMode::WARNING && invalid_i != -1 &&
       static_cast<int64_t>(atomicAdd(
           reinterpret_cast<unsigned long long int*>(&warning[0]), 0)) == 0) {
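
For context, the change replaces one global atomic per thread with one per block: each thread stages its private warning count in shared memory, the block tree-reduces the buffer, and only thread 0 issues the gpuAtomicAdd on warning[0]. Below is a minimal standalone sketch of that pattern. The kernel name, the counts input, and the 32x8 block shape are illustrative rather than taken from the patch; the real kernel accumulates warning_inc during its bounds checks and uses PyTorch's gpuAtomicAdd wrapper. The tree reduction assumes the per-block thread count is a power of two.

// Standalone sketch of the per-block reduction pattern added above.
// Hypothetical names: block_reduced_count, counts, total, kMaxThreadsSketch.
#include <cstdio>
#include <cuda_runtime.h>

constexpr int kMaxThreadsSketch = 256; // stands in for kMaxThreads

__global__ void block_reduced_count(const int* counts, int n, long long* total) {
  __shared__ long long buf[kMaxThreadsSketch];

  // Row-major flattening of the 3-D thread index, as in the patch.
  const int linear_tid = threadIdx.z * (blockDim.y * blockDim.x) +
      threadIdx.y * blockDim.x + threadIdx.x;
  const int active_threads = blockDim.x * blockDim.y * blockDim.z;

  // Each thread stages its private count (0 for out-of-range threads).
  const int gid = blockIdx.x * active_threads + linear_tid;
  buf[linear_tid] = (gid < n) ? counts[gid] : 0;
  __syncthreads();

  // Parallel tree reduction; halves the active range each step, so it
  // requires active_threads to be a power of two.
  for (int stride = active_threads / 2; stride > 0; stride >>= 1) {
    if (linear_tid < stride) {
      buf[linear_tid] += buf[linear_tid + stride];
    }
    __syncthreads();
  }

  // One global atomic per block instead of one per thread.
  if (linear_tid == 0 && buf[0] > 0) {
    atomicAdd(reinterpret_cast<unsigned long long*>(total),
              static_cast<unsigned long long>(buf[0]));
  }
}

int main() {
  const int n = 1 << 20;
  int* counts;
  long long* total;
  cudaMallocManaged(&counts, n * sizeof(int));
  cudaMallocManaged(&total, sizeof(long long));
  for (int i = 0; i < n; ++i) counts[i] = (i % 1000 == 0) ? 1 : 0;
  *total = 0;

  dim3 block(32, 8); // 256 threads, matching kMaxThreadsSketch
  const int grid = (n + 256 - 1) / 256;
  block_reduced_count<<<grid, block>>>(counts, n, total);
  cudaDeviceSynchronize();
  printf("total = %lld\n", *total); // expect 1049

  cudaFree(counts);
  cudaFree(total);
  return 0;
}

With a 256-thread block this compiles under nvcc (or hipcc) and prints total = 1049, issuing one atomicAdd per 256 threads rather than one per thread.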