Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions fbgemm_gpu/codegen/utils/embedding_bounds_check_v2.cu
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,12 @@ __global__ __launch_bounds__(kMaxThreads) void bounds_check_indices_kernel_v2(
index_t invalid_i = -1, invalid_idx = -1;
int32_t invalid_b_t = -1;
int64_t warning_inc = 0;
#ifdef USE_ROCM
__shared__ int64_t block_warning_buffer[kMaxThreads];
const int linear_tid = threadIdx.z * (blockDim.y * blockDim.x) +
threadIdx.y * blockDim.x + threadIdx.x;
const int active_threads = blockDim.x * blockDim.y * blockDim.z;
#endif

// Check the last element
if (b_t_start == 0 && threadIdx.x == 0) {
Expand Down Expand Up @@ -142,9 +148,34 @@ __global__ __launch_bounds__(kMaxThreads) void bounds_check_indices_kernel_v2(
}
} // for b_t

#ifdef USE_ROCM
// Accumulate per-thread warning counts in shared memory and reduce once per
// block.
block_warning_buffer[linear_tid] = warning_inc;
__syncthreads();

// Parallel tree reduction
for (int stride = active_threads / 2; stride > 0; stride >>= 1) {
if (linear_tid < stride) {
block_warning_buffer[linear_tid] +=
block_warning_buffer[linear_tid + stride];
}
__syncthreads();
}

// Thread 0 has the final sum
if (linear_tid == 0) {
int64_t block_warning_sum = block_warning_buffer[0];
if (block_warning_sum > 0) {
gpuAtomicAdd(&warning[0], block_warning_sum);
}
}
__syncthreads();
#else
if (warning_inc > 0) {
gpuAtomicAdd(&warning[0], warning_inc);
}
#endif
if (bounds_check_mode == BoundsCheckMode::WARNING && invalid_i != -1 &&
static_cast<int64_t>(atomicAdd(
reinterpret_cast<unsigned long long int*>(&warning[0]), 0)) == 0) {
Expand Down
Loading