Skip to content

Commit 85949ce

Browse files
liligwufacebook-github-bot
authored andcommitted
minimize gpuAtomicAdd overhead in bounds_check_indices_kernel_v2 (#5124)
Summary: X-link: facebookresearch/FBGEMM#2168 Differential Revision: D87008101 Pulled By: q10
1 parent 903002a commit 85949ce

File tree

2 files changed

+33
-1
lines changed

2 files changed

+33
-1
lines changed

fbgemm_gpu/codegen/utils/embedding_bounds_check_v2.cu

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,12 @@ __global__ __launch_bounds__(kMaxThreads) void bounds_check_indices_kernel_v2(
3131
index_t invalid_i = -1, invalid_idx = -1;
3232
int32_t invalid_b_t = -1;
3333
int64_t warning_inc = 0;
34+
#ifdef USE_ROCM
35+
__shared__ int64_t block_warning_buffer[kMaxThreads];
36+
const int linear_tid = threadIdx.z * (blockDim.y * blockDim.x) +
37+
threadIdx.y * blockDim.x + threadIdx.x;
38+
const int active_threads = blockDim.x * blockDim.y * blockDim.z;
39+
#endif
3440

3541
// Check the last element
3642
if (b_t_start == 0 && threadIdx.x == 0) {
@@ -142,9 +148,34 @@ __global__ __launch_bounds__(kMaxThreads) void bounds_check_indices_kernel_v2(
142148
}
143149
} // for b_t
144150

151+
#ifdef USE_ROCM
152+
// Accumulate per-thread warning counts in shared memory and reduce once per
153+
// block.
154+
block_warning_buffer[linear_tid] = warning_inc;
155+
__syncthreads();
156+
157+
// Parallel tree reduction
158+
for (int stride = active_threads / 2; stride > 0; stride >>= 1) {
159+
if (linear_tid < stride) {
160+
block_warning_buffer[linear_tid] +=
161+
block_warning_buffer[linear_tid + stride];
162+
}
163+
__syncthreads();
164+
}
165+
166+
// Thread 0 has the final sum
167+
if (linear_tid == 0) {
168+
int64_t block_warning_sum = block_warning_buffer[0];
169+
if (block_warning_sum > 0) {
170+
gpuAtomicAdd(&warning[0], block_warning_sum);
171+
}
172+
}
173+
__syncthreads();
174+
#else
145175
if (warning_inc > 0) {
146176
gpuAtomicAdd(&warning[0], warning_inc);
147177
}
178+
#endif
148179
if (bounds_check_mode == BoundsCheckMode::WARNING && invalid_i != -1 &&
149180
static_cast<int64_t>(atomicAdd(
150181
reinterpret_cast<unsigned long long int*>(&warning[0]), 0)) == 0) {

test/EmbeddingSpMDMTest.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,8 @@ TEST_P(EmbeddingSpMDMTest, basicTest) {
144144
int embedding_dim = input[2];
145145
int average_len = input[3];
146146
int output_stride = use_output_input_stride ? embedding_dim * 2 + 3 : -1;
147-
int input_stride = use_output_input_stride ? embedding_dim * 2 + 3 : -1;
147+
int input_stride =
148+
use_output_input_stride ? (embedding_dim * 2 + 3) * 10 : -1;
148149

149150
// Create embedding table
150151
vector<float> embedding_table(

0 commit comments

Comments
 (0)