@@ -31,6 +31,12 @@ __global__ __launch_bounds__(kMaxThreads) void bounds_check_indices_kernel_v2(
   index_t invalid_i = -1, invalid_idx = -1;
   int32_t invalid_b_t = -1;
   int64_t warning_inc = 0;
+#ifdef USE_ROCM
+  __shared__ int64_t block_warning_buffer[kMaxThreads];
+  const int linear_tid = threadIdx.z * (blockDim.y * blockDim.x) +
+      threadIdx.y * blockDim.x + threadIdx.x;
+  const int active_threads = blockDim.x * blockDim.y * blockDim.z;
+#endif

   // Check the last element
   if (b_t_start == 0 && threadIdx.x == 0) {
@@ -142,9 +148,34 @@ __global__ __launch_bounds__(kMaxThreads) void bounds_check_indices_kernel_v2(
     }
   } // for b_t

+#ifdef USE_ROCM
+  // Accumulate per-thread warning counts in shared memory and reduce once per
+  // block.
+  block_warning_buffer[linear_tid] = warning_inc;
+  __syncthreads();
+
+  // Parallel tree reduction
+  for (int stride = active_threads / 2; stride > 0; stride >>= 1) {
+    if (linear_tid < stride) {
+      block_warning_buffer[linear_tid] +=
+          block_warning_buffer[linear_tid + stride];
+    }
+    __syncthreads();
+  }
+
+  // Thread 0 has the final sum
+  if (linear_tid == 0) {
+    int64_t block_warning_sum = block_warning_buffer[0];
+    if (block_warning_sum > 0) {
+      gpuAtomicAdd(&warning[0], block_warning_sum);
+    }
+  }
+  __syncthreads();
+#else
   if (warning_inc > 0) {
     gpuAtomicAdd(&warning[0], warning_inc);
   }
+#endif
   if (bounds_check_mode == BoundsCheckMode::WARNING && invalid_i != -1 &&
       static_cast<int64_t>(atomicAdd(
           reinterpret_cast<unsigned long long int*>(&warning[0]), 0)) == 0) {
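
Below is a minimal standalone CUDA sketch of the block-level reduction pattern the ROCm branch introduces: each thread stages its private count in shared memory, the block performs a tree reduction, and only thread 0 touches the global counter. The kernel name `count_warnings`, the toy input (negative values count as warnings), and `kMaxThreads = 256` are illustrative assumptions rather than part of this patch; like the patch's reduction loop, it assumes the block's thread count is a power of two so every slot is covered.

```cuda
#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

// Hypothetical stand-in for the real kernel's kMaxThreads launch bound.
constexpr int kMaxThreads = 256;

// Toy kernel: each thread decides whether its element triggers a "warning",
// keeps a private count, then the block reduces those counts in shared memory
// and issues a single atomicAdd to the global counter.
__global__ __launch_bounds__(kMaxThreads) void count_warnings(
    const int* data, int n, int64_t* warning) {
  int64_t warning_inc = 0;
  const int gid = blockIdx.x * blockDim.x + threadIdx.x;
  if (gid < n && data[gid] < 0) {  // pretend negative values are invalid
    warning_inc = 1;
  }

  // Per-block staging buffer, one slot per thread (same shape as the
  // block_warning_buffer[kMaxThreads] in the patch).
  __shared__ int64_t block_warning_buffer[kMaxThreads];
  const int linear_tid = threadIdx.z * (blockDim.y * blockDim.x) +
      threadIdx.y * blockDim.x + threadIdx.x;
  const int active_threads = blockDim.x * blockDim.y * blockDim.z;

  block_warning_buffer[linear_tid] = warning_inc;
  __syncthreads();

  // Parallel tree reduction: halve the number of active adders each step.
  for (int stride = active_threads / 2; stride > 0; stride >>= 1) {
    if (linear_tid < stride) {
      block_warning_buffer[linear_tid] +=
          block_warning_buffer[linear_tid + stride];
    }
    __syncthreads();
  }

  // One global atomic per block instead of one per warning-producing thread
  // (plain atomicAdd here, standing in for the patch's gpuAtomicAdd helper).
  if (linear_tid == 0 && block_warning_buffer[0] > 0) {
    atomicAdd(reinterpret_cast<unsigned long long*>(warning),
              static_cast<unsigned long long>(block_warning_buffer[0]));
  }
}

int main() {
  const int n = 1 << 20;
  int* d_data;
  int64_t* d_warning;
  cudaMalloc(&d_data, n * sizeof(int));
  cudaMalloc(&d_warning, sizeof(int64_t));
  cudaMemset(d_data, 0xFF, n * sizeof(int));  // every element becomes -1
  cudaMemset(d_warning, 0, sizeof(int64_t));

  count_warnings<<<(n + kMaxThreads - 1) / kMaxThreads, kMaxThreads>>>(
      d_data, n, d_warning);

  int64_t h_warning = 0;
  cudaMemcpy(&h_warning, d_warning, sizeof(int64_t), cudaMemcpyDeviceToHost);
  printf("warnings: %lld (expected %d)\n", (long long)h_warning, n);

  cudaFree(d_data);
  cudaFree(d_warning);
  return 0;
}
```

The point of the pattern is to replace one contended global atomic per warning-producing thread with at most one atomic per block, which is presumably why the change is gated behind `USE_ROCM`, where hot global atomics on `warning[0]` are comparatively expensive; the non-ROCm path keeps the original per-thread `gpuAtomicAdd`.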