@@ -119,7 +119,7 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncol
             s_sum[warp_id] = tmp;
         }
         __syncthreads();
-        tmp = s_sum[lane_id];
+        tmp = lane_id < block_size/WARP_SIZE ? s_sum[lane_id] : 0.0f;
         tmp = warp_reduce_sum(tmp);
     }

@@ -198,7 +198,7 @@ static __global__ void fused_rms_norm_f32(const float * x, const float * y, floa
             s_sum[warp_id] = tmp;
         }
         __syncthreads();
-        tmp = s_sum[lane_id];
+        tmp = lane_id < block_size/WARP_SIZE ? s_sum[lane_id] : 0.0f;
         tmp = warp_reduce_sum(tmp);
     }

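Both guards above fix the same pattern: the block-wide sum runs in two stages, and only block_size/WARP_SIZE slots of s_sum are ever written, so in a block with fewer than 32 warps the unguarded read pulled uninitialized shared memory into the second warp_reduce_sum. A minimal, self-contained sketch of that two-stage reduction with the guard in place (block_sum_demo is an illustrative name, not a kernel from this file):

#include <cuda_runtime.h>

#define WARP_SIZE 32

static __device__ float warp_reduce_sum(float x) {
    // butterfly reduction within one warp
    for (int offset = WARP_SIZE/2; offset > 0; offset >>= 1) {
        x += __shfl_xor_sync(0xffffffff, x, offset);
    }
    return x;
}

template <int block_size>
static __global__ void block_sum_demo(const float * x, float * out, const int n) {
    float tmp = 0.0f;
    for (int i = threadIdx.x; i < n; i += block_size) {
        tmp += x[blockIdx.x*n + i];
    }
    tmp = warp_reduce_sum(tmp);           // stage 1: per-warp partial sums

    if (block_size > WARP_SIZE) {
        __shared__ float s_sum[32];
        const int warp_id = threadIdx.x / WARP_SIZE;
        const int lane_id = threadIdx.x % WARP_SIZE;
        if (lane_id == 0) {
            s_sum[warp_id] = tmp;         // one partial sum per warp
        }
        __syncthreads();
        // only block_size/WARP_SIZE entries of s_sum were written; without the
        // guard, lanes past that index would fold stale shared memory into the sum
        tmp = lane_id < block_size/WARP_SIZE ? s_sum[lane_id] : 0.0f;
        tmp = warp_reduce_sum(tmp);       // stage 2: combine the warp sums
    }

    if (threadIdx.x == 0) {
        out[blockIdx.x] = tmp;
    }
}

This only starts to matter with the launch changes further down: with kBlockSize = 256 there are 8 warps, so lanes 8-31 would otherwise read stale slots, whereas the old WARP_SIZE and 1024-thread launches never hit the bad case.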
@@ -219,6 +219,7 @@ static __global__ void fused_rms_norm_f32_nc(

     const int row = blockIdx.x;
     const int channel = blockIdx.y;
+    // const int channel = blockIdx.y * blockDim.y + threadIdx.y;
     const int sample = blockIdx.z;
     const int tid = threadIdx.x;

@@ -244,6 +245,11 @@ static __global__ void fused_rms_norm_f32_nc(
         }
         __syncthreads();
         tmp = s_sum[lane_id];
+        // if constexpr (block_size == 1024) {
+        //     tmp = s_sum[lane_id];
+        // } else {
+        //     tmp = lane_id < block_size/WARP_SIZE ? s_sum[lane_id] : 0.0f;
+        // }
         tmp = warp_reduce_sum(tmp);
     }

@@ -278,9 +284,10 @@ static void group_norm_f32_cuda(const float * x, float * dst, const int num_grou

 static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
     GGML_ASSERT(ncols % WARP_SIZE == 0);
+    constexpr int kBlockSize = 256;
     if (ncols < 1024) {
-        const dim3 block_dims(WARP_SIZE, 1, 1);
-        rms_norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
+        const dim3 block_dims(kBlockSize, 1, 1);
+        rms_norm_f32<kBlockSize><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
     } else {
         const dim3 block_dims(1024, 1, 1);
         rms_norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
@@ -302,10 +309,22 @@ static void rms_norm_f32_nc_cuda(

 static void fused_rms_norm_f32_cuda(const float * x, const float * y, float * dst,
                                     const int ncols, const int nrows, const float eps, cudaStream_t stream) {
+    constexpr int kBlockSize = 256;
     GGML_ASSERT(ncols % WARP_SIZE == 0);
-    if (ncols < 1024) {
-        const dim3 block_dims(WARP_SIZE, 1, 1);
-        fused_rms_norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, y, dst, ncols, eps);
+    if (ncols < kBlockSize) {
+        switch (ncols) {
+            case  32: fused_rms_norm_f32< 32><<<nrows,  32, 0, stream>>>(x, y, dst, ncols, eps); break;
+            case  64: fused_rms_norm_f32< 64><<<nrows,  64, 0, stream>>>(x, y, dst, ncols, eps); break;
+            case  96: fused_rms_norm_f32< 96><<<nrows,  96, 0, stream>>>(x, y, dst, ncols, eps); break;
+            case 128: fused_rms_norm_f32<128><<<nrows, 128, 0, stream>>>(x, y, dst, ncols, eps); break;
+            case 160: fused_rms_norm_f32<160><<<nrows, 160, 0, stream>>>(x, y, dst, ncols, eps); break;
+            case 192: fused_rms_norm_f32<192><<<nrows, 192, 0, stream>>>(x, y, dst, ncols, eps); break;
+            default : fused_rms_norm_f32<224><<<nrows, 224, 0, stream>>>(x, y, dst, ncols, eps); break;
+        }
+    }
+    else if (ncols < 1024) {
+        const dim3 block_dims(kBlockSize, 1, 1);
+        fused_rms_norm_f32<kBlockSize><<<nrows, block_dims, 0, stream>>>(x, y, dst, ncols, eps);
     } else {
         const dim3 block_dims(1024, 1, 1);
         fused_rms_norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, y, dst, ncols, eps);
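For rows narrower than kBlockSize, the switch above picks a template instantiation whose block size equals ncols, so every launched thread maps to exactly one column instead of padding out a 256-thread block with idle threads. The shape of that runtime-to-compile-time dispatch, reduced to a standalone sketch (demo_kernel and demo_dispatch are placeholder names, not kernels from this file):

#include <cuda_runtime.h>

template <int block_size>
static __global__ void demo_kernel(const float * x, float * dst, const int ncols) {
    // block_size is a compile-time constant, so for ncols == block_size this
    // loop collapses to a single iteration per thread
    const int row = blockIdx.x;
    for (int col = threadIdx.x; col < ncols; col += block_size) {
        dst[row*ncols + col] = x[row*ncols + col];
    }
}

static void demo_dispatch(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
    // map the runtime column count onto a matching compile-time block size
    switch (ncols) {
        case  32: demo_kernel< 32><<<nrows,  32, 0, stream>>>(x, dst, ncols); break;
        case  64: demo_kernel< 64><<<nrows,  64, 0, stream>>>(x, dst, ncols); break;
        case 128: demo_kernel<128><<<nrows, 128, 0, stream>>>(x, dst, ncols); break;
        default : demo_kernel<256><<<nrows, 256, 0, stream>>>(x, dst, ncols); break;
    }
}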
@@ -319,6 +338,16 @@ static void fused_rms_norm_f32_nc_cuda(
     if (ncols < 1024) {
         const dim3 block_dims(WARP_SIZE, 1, 1);
         fused_rms_norm_f32_nc<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, y, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+        // constexpr int kBlockSize = 256;
+
+        // if (nchannels%4 == 0) {
+        //     const dim3 blocks_num(nrows, nchannels/4, nsamples);
+        //     const dim3 block_dims(kBlockSize, 4, 1);
+        //     fused_rms_norm_f32_nc<kBlockSize><<<blocks_num, block_dims, 0, stream>>>(x, y, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+        // } else {
+        //     const dim3 block_dims(kBlockSize, 1, 1);
+        //     fused_rms_norm_f32_nc<kBlockSize><<<blocks_num, block_dims, 0, stream>>>(x, y, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+        // }
     } else {
         const dim3 block_dims(1024, 1, 1);
         fused_rms_norm_f32_nc<1024><<<blocks_num, block_dims, 0, stream>>>(x, y, dst, ncols, stride_row, stride_channel, stride_sample, eps);