@@ -105,29 +105,29 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
105105}
106106
107107template <int block_size, bool do_multiply = false , bool do_add = false >
108- static __global__ void rms_norm_f32 (const float * x, float * dst,
108+ static __global__ void rms_norm_f32 (const float * x,
109+ float * dst,
109110 const int ncols,
110111 const int64_t stride_row,
111112 const int64_t stride_channel,
112113 const int64_t stride_sample,
113114 const float eps,
114- const float * mul = nullptr ,
115- const int64_t mul_stride_row = 0 ,
116- const int64_t mul_stride_channel = 0 ,
117- const int64_t mul_stride_sample = 0 ,
118- const int mul_ncols = 0 ,
119- const int mul_nrows = 0 ,
120- const int mul_nchannels = 0 ,
121- const int mul_nsamples = 0 ,
122- const float * add = nullptr ,
123- const int64_t add_stride_row = 0 ,
124- const int64_t add_stride_channel = 0 ,
125- const int64_t add_stride_sample = 0 ,
126- const int add_ncols = 0 ,
127- const int add_nrows = 0 ,
128- const int add_nchannels = 0 ,
129- const int add_nsamples = 0 ) {
130-
115+ const float * mul = nullptr ,
116+ const int64_t mul_stride_row = 0 ,
117+ const int64_t mul_stride_channel = 0 ,
118+ const int64_t mul_stride_sample = 0 ,
119+ const uint3 mul_ncols_packed = make_uint3(0 , 0 , 0 ),
120+ const uint3 mul_nrows_packed = make_uint3(0 , 0 , 0 ),
121+ const uint3 mul_nchannels_packed = make_uint3(0 , 0 , 0 ),
122+ const uint3 mul_nsamples_packed = make_uint3(0 , 0 , 0 ),
123+ const float * add = nullptr,
124+ const int64_t add_stride_row = 0,
125+ const int64_t add_stride_channel = 0,
126+ const int64_t add_stride_sample = 0,
127+ const uint3 add_ncols_packed = make_uint3(0 , 0 , 0 ),
128+ const uint3 add_nrows_packed = make_uint3(0 , 0 , 0 ),
129+ const uint3 add_nchannels_packed = make_uint3(0 , 0 , 0 ),
130+ const uint3 add_nsamples_packed = make_uint3(0 , 0 , 0 )) {
131131 const int nrows = gridDim .x ;
132132 const int nchannels = gridDim .y ;
133133
@@ -142,16 +142,16 @@ static __global__ void rms_norm_f32(const float * x, float * dst,
142142 dst += ((sample*nchannels + channel)*nrows + row)*ncols;
143143
144144 if constexpr (do_multiply) {
145- const int mul_row = row % mul_nrows ;
146- const int mul_channel = channel % mul_nchannels ;
147- const int mul_sample = sample % mul_nsamples ;
148- mul += mul_sample* mul_stride_sample + mul_channel* mul_stride_channel + mul_row* mul_stride_row;
145+ const uint32_t mul_row = fastmodulo ( row, mul_nrows_packed) ;
146+ const uint32_t mul_channel = fastmodulo ( channel, mul_nchannels_packed) ;
147+ const uint32_t mul_sample = fastmodulo ( sample, mul_nsamples_packed) ;
148+ mul += mul_sample * mul_stride_sample + mul_channel * mul_stride_channel + mul_row * mul_stride_row;
149149 }
150150
151151 if constexpr (do_add) {
152- const int add_row = row % add_nrows ;
153- const int add_channel = channel % add_nchannels ;
154- const int add_sample = sample % add_nsamples ;
152+ const int add_row = fastmodulo ( row, add_nrows_packed) ;
153+ const int add_channel = fastmodulo ( channel, add_nchannels_packed) ;
154+ const int add_sample = fastmodulo ( sample, add_nsamples_packed) ;
155155 add += add_sample * add_stride_sample + add_channel * add_stride_channel + add_row * add_stride_row;
156156 }
157157
@@ -165,15 +165,18 @@ static __global__ void rms_norm_f32(const float * x, float * dst,
165165 // sum up partial sums
166166 tmp = warp_reduce_sum (tmp);
167167 if constexpr (block_size > WARP_SIZE) {
168- static_assert (block_size == 1024 , " unexpected block_size" );
168+ static_assert (( block_size <= 1024 ) && (block_size % 32 == 0 ) , " unexpected block_size" );
169169 __shared__ float s_sum[32 ];
170- const int warp_id = threadIdx . x / WARP_SIZE;
171- const int lane_id = threadIdx . x % WARP_SIZE;
170+ const int warp_id = tid / WARP_SIZE;
171+ const int lane_id = tid % WARP_SIZE;
172172 if (lane_id == 0 ) {
173173 s_sum[warp_id] = tmp;
174174 }
175175 __syncthreads ();
176- tmp = s_sum[lane_id];
176+ tmp = 0 .0f ;
177+ if (lane_id < (block_size / WARP_SIZE)) {
178+ tmp = s_sum[lane_id];
179+ }
177180 tmp = warp_reduce_sum (tmp);
178181 }
179182
@@ -182,12 +185,12 @@ static __global__ void rms_norm_f32(const float * x, float * dst,
182185
183186 for (int col = tid; col < ncols; col += block_size) {
184187 if constexpr (do_multiply && do_add) {
185- const int mul_col = col % mul_ncols ;
186- const int add_col = col % add_ncols ;
187- dst[col] = scale * x[col] * mul[mul_col] + add[add_col];
188+ const int mul_col = fastmodulo ( col, mul_ncols_packed) ;
189+ const int add_col = fastmodulo ( col, add_ncols_packed) ;
190+ dst[col] = scale * x[col] * mul[mul_col] + add[add_col];
188191 } else if constexpr (do_multiply) {
189- const int mul_col = col % mul_ncols ;
190- dst[col] = scale * x[col] * mul[mul_col];
192+ const int mul_col = fastmodulo ( col, mul_ncols_packed) ;
193+ dst[col] = scale * x[col] * mul[mul_col];
191194 } else {
192195 dst[col] = scale * x[col];
193196 }
@@ -354,77 +357,86 @@ static void rms_norm_f32_cuda(
354357 const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
355358 const dim3 blocks_num (nrows, nchannels, nsamples);
356359 if (ncols < 1024 ) {
357- const dim3 block_dims (WARP_SIZE , 1 , 1 );
358- rms_norm_f32<WARP_SIZE , false ><<<blocks_num, block_dims, 0 , stream>>> (x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
360+ const dim3 block_dims (256 , 1 , 1 );
361+ rms_norm_f32<256 , false ><<<blocks_num, block_dims, 0 , stream>>> (x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
359362 } else {
360363 const dim3 block_dims (1024 , 1 , 1 );
361364 rms_norm_f32<1024 , false ><<<blocks_num, block_dims, 0 , stream>>> (x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
362365 }
363366}
364367
365- static void rms_norm_mul_f32_cuda (const float * x,
366- const float * mul,
367- const float * add,
368- float * dst,
369- const int ncols,
370- const int nrows,
371- const int nchannels,
372- const int nsamples,
373- const int64_t stride_row,
374- const int64_t stride_channel,
375- const int64_t stride_sample,
376- const int64_t mul_stride_row,
377- const int64_t mul_stride_channel,
378- const int64_t mul_stride_sample,
379- const int mul_ncols,
380- const int mul_nrows,
381- const int mul_nchannels,
382- const int mul_nsamples,
383- const int64_t add_stride_row,
384- const int64_t add_stride_channel,
385- const int64_t add_stride_sample,
386- const int add_ncols,
387- const int add_nrows,
388- const int add_nchannels,
389- const int add_nsamples,
390- const float eps,
391- cudaStream_t stream) {
368+ static void rms_norm_mul_f32_cuda (const float * x,
369+ const float * mul,
370+ const float * add,
371+ float * dst,
372+ const int ncols,
373+ const int nrows,
374+ const int nchannels,
375+ const int nsamples,
376+ const int64_t stride_row,
377+ const int64_t stride_channel,
378+ const int64_t stride_sample,
379+ const int64_t mul_stride_row,
380+ const int64_t mul_stride_channel,
381+ const int64_t mul_stride_sample,
382+ const uint32_t mul_ncols,
383+ const uint32_t mul_nrows,
384+ const uint32_t mul_nchannels,
385+ const uint32_t mul_nsamples,
386+ const int64_t add_stride_row,
387+ const int64_t add_stride_channel,
388+ const int64_t add_stride_sample,
389+ const uint32_t add_ncols,
390+ const uint32_t add_nrows,
391+ const uint32_t add_nchannels,
392+ const uint32_t add_nsamples,
393+ const float eps,
394+ cudaStream_t stream) {
392395 const dim3 blocks_num (nrows, nchannels, nsamples);
393396 if (mul == nullptr ) {
394397 rms_norm_f32_cuda (x, dst, ncols, nrows, nchannels, nsamples, stride_row, stride_channel, stride_sample, eps, stream);
395398 return ;
396399 }
397400 if (add == nullptr ) {
401+ const uint3 mul_ncols_packed = init_fastdiv_values (mul_ncols);
402+ const uint3 mul_nrows_packed = init_fastdiv_values (mul_nrows);
403+ const uint3 mul_nchannels_packed = init_fastdiv_values (mul_nchannels);
404+ const uint3 mul_nsamples_packed = init_fastdiv_values (mul_nsamples);
398405 if (ncols < 1024 ) {
399- const dim3 block_dims (WARP_SIZE, 1 , 1 );
400- rms_norm_f32<WARP_SIZE, true ><<<blocks_num, block_dims, 0 , stream>>> (x, dst,
401- ncols, stride_row, stride_channel, stride_sample, eps,
402- mul, mul_stride_row, mul_stride_channel, mul_stride_sample,
403- mul_ncols, mul_nrows, mul_nchannels, mul_nsamples);
406+ const dim3 block_dims (256 , 1 , 1 );
407+ rms_norm_f32<256 , true ><<<blocks_num, block_dims, 0 , stream>>> (
408+ x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
409+ mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
404410 } else {
405411 const dim3 block_dims (1024 , 1 , 1 );
406- rms_norm_f32<1024 , true ><<<blocks_num, block_dims, 0 , stream>>> (x, dst,
407- ncols, stride_row, stride_channel, stride_sample, eps,
408- mul, mul_stride_row, mul_stride_channel, mul_stride_sample,
409- mul_ncols, mul_nrows, mul_nchannels, mul_nsamples);
412+ rms_norm_f32<1024 , true ><<<blocks_num, block_dims, 0 , stream>>> (
413+ x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
414+ mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
410415 }
411416 } else {
417+ const uint3 mul_ncols_packed = init_fastdiv_values (mul_ncols);
418+ const uint3 mul_nrows_packed = init_fastdiv_values (mul_nrows);
419+ const uint3 mul_nchannels_packed = init_fastdiv_values (mul_nchannels);
420+ const uint3 mul_nsamples_packed = init_fastdiv_values (mul_nsamples);
421+
422+ const uint3 add_ncols_packed = init_fastdiv_values (add_ncols);
423+ const uint3 add_nrows_packed = init_fastdiv_values (add_nrows);
424+ const uint3 add_nchannels_packed = init_fastdiv_values (add_nchannels);
425+ const uint3 add_nsamples_packed = init_fastdiv_values (add_nsamples);
412426 if (ncols < 1024 ) {
413- const dim3 block_dims (WARP_SIZE, 1 , 1 );
414- rms_norm_f32<WARP_SIZE, true , true ><<<blocks_num, block_dims, 0 , stream>>> (x, dst,
415- ncols, stride_row, stride_channel, stride_sample, eps,
416- mul, mul_stride_row, mul_stride_channel, mul_stride_sample,
417- mul_ncols, mul_nrows, mul_nchannels, mul_nsamples,
418- add, add_stride_row, add_stride_channel, add_stride_sample,
419- add_ncols, add_nrows, add_nchannels, add_nsamples);
427+ const dim3 block_dims (256 , 1 , 1 );
428+ rms_norm_f32<256 , true , true ><<<blocks_num, block_dims, 0 , stream>>> (
429+ x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
430+ mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
431+ add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
432+ add_nchannels_packed, add_nsamples_packed);
420433 } else {
421434 const dim3 block_dims (1024 , 1 , 1 );
422- rms_norm_f32<1024 , true , true ><<<blocks_num, block_dims, 0 , stream>>> (x, dst,
423- ncols, stride_row, stride_channel, stride_sample, eps,
424- mul, mul_stride_row, mul_stride_channel, mul_stride_sample,
425- mul_ncols, mul_nrows, mul_nchannels, mul_nsamples,
426- add, add_stride_row, add_stride_channel, add_stride_sample,
427- add_ncols, add_nrows, add_nchannels, add_nsamples);
435+ rms_norm_f32<1024 , true , true ><<<blocks_num, block_dims, 0 , stream>>> (
436+ x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
437+ mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
438+ add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
439+ add_nchannels_packed, add_nsamples_packed);
428440 }
429441 }
430442}
0 commit comments