diff --git a/fbgemm_gpu/codegen/utils/embedding_bounds_check_v2.cu b/fbgemm_gpu/codegen/utils/embedding_bounds_check_v2.cu
index 2f7bfc8bb7..27811cb0ec 100644
--- a/fbgemm_gpu/codegen/utils/embedding_bounds_check_v2.cu
+++ b/fbgemm_gpu/codegen/utils/embedding_bounds_check_v2.cu
@@ -28,9 +28,15 @@ __global__ __launch_bounds__(kMaxThreads) void bounds_check_indices_kernel_v2(
   const index_t num_indices = indices.size(0);
   const auto b_t_start = blockIdx.x * blockDim.y + threadIdx.y;

   index_t invalid_i = -1, invalid_idx = -1;
   int32_t invalid_b_t = -1;
   int64_t warning_inc = 0;
+#ifdef USE_ROCM
+  __shared__ int64_t block_warning_buffer[kMaxThreads];
+  const int linear_tid = threadIdx.z * (blockDim.y * blockDim.x) +
+      threadIdx.y * blockDim.x + threadIdx.x;
+  const int active_threads = blockDim.x * blockDim.y * blockDim.z;
+#endif

   // Check the last element
   if (b_t_start == 0 && threadIdx.x == 0) {
@@ -142,9 +148,32 @@ __global__ __launch_bounds__(kMaxThreads) void bounds_check_indices_kernel_v2(
     }
   } // for b_t

+#ifdef USE_ROCM
+  // Accumulate per-thread warning counts in shared memory and reduce once per block.
+  block_warning_buffer[linear_tid] = warning_inc;
+  __syncthreads();
+
+  // Parallel tree reduction
+  for (int stride = active_threads / 2; stride > 0; stride >>= 1) {
+    if (linear_tid < stride) {
+      block_warning_buffer[linear_tid] += block_warning_buffer[linear_tid + stride];
+    }
+    __syncthreads();
+  }
+
+  // Thread 0 has the final sum
+  if (linear_tid == 0) {
+    int64_t block_warning_sum = block_warning_buffer[0];
+    if (block_warning_sum > 0) {
+      gpuAtomicAdd(&warning[0], block_warning_sum);
+    }
+  }
+  __syncthreads();
+#else
   if (warning_inc > 0) {
     gpuAtomicAdd(&warning[0], warning_inc);
   }
+#endif
   if (bounds_check_mode == BoundsCheckMode::WARNING && invalid_i != -1 &&
       static_cast<int64_t>(atomicAdd(
           reinterpret_cast<unsigned long long int*>(&warning[0]), 0)) == 0) {
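
For readers less familiar with the pattern, the ROCm branch above trades one gpuAtomicAdd per thread on the global warning counter for a block-wide shared-memory tree reduction followed by a single atomic per block. The standalone kernel below is a minimal sketch of that idea outside FBGEMM; the names (count_kernel, kBlockSize, flags) are illustrative only, and it assumes the block size is a power of two, which the stride-halving loop requires.

#include <cstdio>
#include <cuda_runtime.h>

// Illustrative block size; must be a power of two for the stride-halving loop.
constexpr int kBlockSize = 256;

// Each thread derives a local count, the block reduces it in shared memory,
// and only thread 0 touches the global counter with a single atomicAdd.
__global__ void count_kernel(const int* flags, int n, unsigned long long* global_count) {
  __shared__ unsigned long long block_buffer[kBlockSize];

  const int tid = threadIdx.x;
  const int gid = blockIdx.x * blockDim.x + threadIdx.x;

  // Per-thread partial count (stands in for warning_inc in the diff).
  block_buffer[tid] = (gid < n && flags[gid] != 0) ? 1ULL : 0ULL;
  __syncthreads();

  // Parallel tree reduction over the block, halving the active range each step.
  for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
    if (tid < stride) {
      block_buffer[tid] += block_buffer[tid + stride];
    }
    __syncthreads();
  }

  // One global atomic per block instead of one per thread.
  if (tid == 0 && block_buffer[0] > 0) {
    atomicAdd(global_count, block_buffer[0]);
  }
}

int main() {
  const int n = 1 << 20;
  int* d_flags = nullptr;
  unsigned long long* d_count = nullptr;
  cudaMalloc(&d_flags, n * sizeof(int));
  cudaMalloc(&d_count, sizeof(unsigned long long));
  cudaMemset(d_flags, 0, n * sizeof(int));
  cudaMemset(d_count, 0, sizeof(unsigned long long));

  // Mark one "violation" so the reduced count is nonzero.
  const int one = 1;
  cudaMemcpy(d_flags + 12345, &one, sizeof(int), cudaMemcpyHostToDevice);

  const int blocks = (n + kBlockSize - 1) / kBlockSize;
  count_kernel<<<blocks, kBlockSize>>>(d_flags, n, d_count);

  unsigned long long h_count = 0;
  cudaMemcpy(&h_count, d_count, sizeof(h_count), cudaMemcpyDeviceToHost);
  printf("violations: %llu\n", h_count);

  cudaFree(d_flags);
  cudaFree(d_count);
  return 0;
}

The trade-off mirrors the diff: a small shared-memory buffer and a few __syncthreads() per block are usually cheaper than having every thread contend on the same global counter, which is the contention the ROCm path avoids.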