Skip to content

Commit bcc6c77

Browse files
committed
Support more block_size values in rms_norm_f32
This makes us more flexible in selecting the optimal threads w.r.t paralellizing across a col vs. launch-overheads of threads and mio throttles
1 parent b2e9831 commit bcc6c77

File tree

1 file changed

+13
-10
lines changed

1 file changed

+13
-10
lines changed

ggml/src/ggml-cuda/norm.cu

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -181,15 +181,18 @@ static __global__ void rms_norm_f32(const float * x,
181181
// sum up partial sums
182182
tmp = warp_reduce_sum(tmp);
183183
if constexpr (block_size > WARP_SIZE) {
184-
static_assert(block_size == 1024, "unexpected block_size");
184+
static_assert((block_size <= 1024) && (block_size % 32 == 0), "unexpected block_size");
185185
__shared__ float s_sum[32];
186-
const int warp_id = threadIdx.x / WARP_SIZE;
187-
const int lane_id = threadIdx.x % WARP_SIZE;
186+
const int warp_id = tid / WARP_SIZE;
187+
const int lane_id = tid % WARP_SIZE;
188188
if (lane_id == 0) {
189189
s_sum[warp_id] = tmp;
190190
}
191191
__syncthreads();
192-
tmp = s_sum[lane_id];
192+
tmp = 0.0f;
193+
if (lane_id < (block_size / WARP_SIZE)) {
194+
tmp = s_sum[lane_id];
195+
}
193196
tmp = warp_reduce_sum(tmp);
194197
}
195198

@@ -370,8 +373,8 @@ static void rms_norm_f32_cuda(
370373
const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
371374
const dim3 blocks_num(nrows, nchannels, nsamples);
372375
if (ncols < 1024) {
373-
const dim3 block_dims(WARP_SIZE, 1, 1);
374-
rms_norm_f32<WARP_SIZE, false><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
376+
const dim3 block_dims(256, 1, 1);
377+
rms_norm_f32<256, false><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
375378
} else {
376379
const dim3 block_dims(1024, 1, 1);
377380
rms_norm_f32<1024, false><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
@@ -420,8 +423,8 @@ static void rms_norm_mul_f32_cuda(const float * x,
420423
uint32_t mp_mul_samples, L_mul_samples;
421424
init_fastdiv_values(mul_nsamples, mp_mul_samples, L_mul_samples);
422425
if (ncols < 1024) {
423-
const dim3 block_dims(WARP_SIZE, 1, 1);
424-
rms_norm_f32<WARP_SIZE, true><<<blocks_num, block_dims, 0, stream>>>(x,
426+
const dim3 block_dims(256, 1, 1);
427+
rms_norm_f32<256, true><<<blocks_num, block_dims, 0, stream>>>(x,
425428
dst,
426429
ncols,
427430
stride_row,
@@ -489,8 +492,8 @@ static void rms_norm_mul_f32_cuda(const float * x,
489492
uint32_t mp_add_samples, L_add_samples;
490493
init_fastdiv_values(add_nsamples, mp_add_samples, L_add_samples);
491494
if (ncols < 1024) {
492-
const dim3 block_dims(WARP_SIZE, 1, 1);
493-
rms_norm_f32<WARP_SIZE, true, true><<<blocks_num, block_dims, 0, stream>>>(x,
495+
const dim3 block_dims(256, 1, 1);
496+
rms_norm_f32<256, true, true><<<blocks_num, block_dims, 0, stream>>>(x,
494497
dst,
495498
ncols,
496499
stride_row,

0 commit comments

Comments
 (0)