Skip to content

Commit 685d367

Browse files
committed
sync with pytorch#5124
1 parent 683bfb5 commit 685d367

File tree

1 file changed

+23
-2
lines changed

1 file changed

+23
-2
lines changed

fbgemm_gpu/codegen/utils/embedding_bounds_check_v2.cu

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@ __global__ __launch_bounds__(kMaxThreads) void bounds_check_indices_kernel_v2(
3131
index_t invalid_i = -1, invalid_idx = -1;
3232
int32_t invalid_b_t = -1;
3333
int64_t warning_inc = 0;
34+
__shared__ int64_t block_warning_buffer[kMaxThreads];
35+
const int linear_tid = threadIdx.z * (blockDim.y * blockDim.x) +
36+
threadIdx.y * blockDim.x + threadIdx.x;
37+
const int active_threads = blockDim.x * blockDim.y * blockDim.z;
3438

3539
// Check the last element
3640
if (b_t_start == 0 && threadIdx.x == 0) {
@@ -142,9 +146,26 @@ __global__ __launch_bounds__(kMaxThreads) void bounds_check_indices_kernel_v2(
142146
}
143147
} // for b_t
144148

145-
if (warning_inc > 0) {
146-
gpuAtomicAdd(&warning[0], warning_inc);
149+
// Accumulate per-thread warning counts in shared memory and reduce once per block.
150+
block_warning_buffer[linear_tid] = warning_inc;
151+
__syncthreads();
152+
153+
// Parallel tree reduction
154+
for (int stride = active_threads / 2; stride > 0; stride >>= 1) {
155+
if (linear_tid < stride) {
156+
block_warning_buffer[linear_tid] += block_warning_buffer[linear_tid + stride];
157+
}
158+
__syncthreads();
159+
}
160+
161+
// Thread 0 has the final sum
162+
if (linear_tid == 0) {
163+
int64_t block_warning_sum = block_warning_buffer[0];
164+
if (block_warning_sum > 0) {
165+
gpuAtomicAdd(&warning[0], block_warning_sum);
166+
}
147167
}
168+
__syncthreads();
148169
if (bounds_check_mode == BoundsCheckMode::WARNING && invalid_i != -1 &&
149170
static_cast<int64_t>(atomicAdd(
150171
reinterpret_cast<unsigned long long int*>(&warning[0]), 0)) == 0) {

0 commit comments

Comments
 (0)