@@ -131,6 +131,51 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncol
     }
 }
 
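+// RMS norm for a non-contiguous f32 tensor: one thread block per row, with
+// blockIdx.y/blockIdx.z indexing the channel and sample; strides are given in elements.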
+template <int block_size>
+static __global__ void rms_norm_f32_nc(
+        const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
+        const int64_t stride_sample, const float eps) {
+    const int nrows     = gridDim.x;
+    const int nchannels = gridDim.y;
+
+    const int row     = blockIdx.x;
+    const int channel = blockIdx.y;
+    const int sample  = blockIdx.z;
+    const int tid     = threadIdx.x;
+
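+    // offset x by the (element) strides of the source tensor; dst is always written contiguously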
+    x   += sample*stride_sample + channel*stride_channel + row*stride_row;
+    dst += ((sample*nchannels + channel)*nrows + row)*ncols;
+
+    float tmp = 0.0f; // partial sum for thread in warp
+
+    for (int col = tid; col < ncols; col += block_size) {
+        const float xi = x[col];
+        tmp += xi * xi;
+    }
+
+    // sum up partial sums
+    tmp = warp_reduce_sum(tmp);
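+    // for block sizes larger than one warp, combine the per-warp sums through shared memory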
+    if constexpr (block_size > WARP_SIZE) {
+        static_assert(block_size == 1024, "unexpected block_size");
+        __shared__ float s_sum[32];
+        const int warp_id = threadIdx.x / WARP_SIZE;
+        const int lane_id = threadIdx.x % WARP_SIZE;
+        if (lane_id == 0) {
+            s_sum[warp_id] = tmp;
+        }
+        __syncthreads();
+        tmp = s_sum[lane_id];
+        tmp = warp_reduce_sum(tmp);
+    }
+
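+    // RMS normalization: scale each element by 1/sqrt(mean(x^2) + eps)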
+    const float mean = tmp / ncols;
+    const float scale = rsqrtf(mean + eps);
+
+    for (int col = tid; col < ncols; col += block_size) {
+        dst[col] = scale * x[col];
+    }
+}
+
 template <int block_size>
 static __global__ void fused_rms_norm_f32(const float * x, const float * y, float * dst, const int ncols, const float eps) {
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
@@ -165,6 +210,51 @@ static __global__ void fused_rms_norm_f32(const float * x, const float * y, floa
     }
 }
 
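+// Same as rms_norm_f32_nc, but additionally multiplies by the norm weights y (fused RMS norm + mul).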
+template <int block_size>
+static __global__ void fused_rms_norm_f32_nc(
+        const float * x, const float * y, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
+        const int64_t stride_sample, const float eps) {
+    const int nrows     = gridDim.x;
+    const int nchannels = gridDim.y;
+
+    const int row     = blockIdx.x;
+    const int channel = blockIdx.y;
+    const int sample  = blockIdx.z;
+    const int tid     = threadIdx.x;
+
+    x   += sample*stride_sample + channel*stride_channel + row*stride_row;
+    dst += ((sample*nchannels + channel)*nrows + row)*ncols;
+
+    float tmp = 0.0f; // partial sum for thread in warp
+
+    for (int col = tid; col < ncols; col += block_size) {
+        const float xi = x[col];
+        tmp += xi * xi;
+    }
+
+    // sum up partial sums
+    tmp = warp_reduce_sum(tmp);
+    if constexpr (block_size > WARP_SIZE) {
+        static_assert(block_size == 1024, "unexpected block_size");
+        __shared__ float s_sum[32];
+        const int warp_id = threadIdx.x / WARP_SIZE;
+        const int lane_id = threadIdx.x % WARP_SIZE;
+        if (lane_id == 0) {
+            s_sum[warp_id] = tmp;
+        }
+        __syncthreads();
+        tmp = s_sum[lane_id];
+        tmp = warp_reduce_sum(tmp);
+    }
+
+    const float mean = tmp / ncols;
+    const float scale = rsqrtf(mean + eps);
+
+    for (int col = tid; col < ncols; col += block_size) {
+        dst[col] = scale * y[col] * x[col];
+    }
+}
+
 static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
     GGML_ASSERT(ncols % WARP_SIZE == 0);
     if (ncols < 1024) {
@@ -197,6 +287,19 @@ static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, con
     }
 }
 
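+// Launch helper for the non-contiguous kernel: one block per (row, channel, sample);
+// short rows use a single warp, longer rows a full 1024-thread block.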
+static void rms_norm_f32_nc_cuda(
+        const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
+        const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
+    const dim3 blocks_num(nrows, nchannels, nsamples);
+    if (ncols < 1024) {
+        const dim3 block_dims(WARP_SIZE, 1, 1);
+        rms_norm_f32_nc<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+    } else {
+        const dim3 block_dims(1024, 1, 1);
+        rms_norm_f32_nc<1024><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+    }
+}
+
 static void fused_rms_norm_f32_cuda(const float * x, const float * y, float * dst,
         const int ncols, const int nrows, const float eps, cudaStream_t stream) {
     GGML_ASSERT(ncols % WARP_SIZE == 0);
@@ -209,6 +312,19 @@ static void fused_rms_norm_f32_cuda(const float * x, const float * y, float * ds
     }
 }
 
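+// Launch helper for the fused non-contiguous kernel; same block-size heuristic as rms_norm_f32_nc_cuda.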
+static void fused_rms_norm_f32_nc_cuda(
+        const float * x, const float * y, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
+        const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
+    const dim3 blocks_num(nrows, nchannels, nsamples);
+    if (ncols < 1024) {
+        const dim3 block_dims(WARP_SIZE, 1, 1);
+        fused_rms_norm_f32_nc<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, y, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+    } else {
+        const dim3 block_dims(1024, 1, 1);
+        fused_rms_norm_f32_nc<1024><<<blocks_num, block_dims, 0, stream>>>(x, y, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+    }
+}
+
 void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const float * src0_d = (const float *)src0->data;
@@ -255,18 +371,24 @@ void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     float * dst_d = (float *)dst->data;
     cudaStream_t stream = ctx.stream();
 
-    GGML_ASSERT(ggml_is_contiguous(src0));
-
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
-    const int64_t ne00 = src0->ne[0];
-    const int64_t nrows = ggml_nrows(src0);
-
     float eps;
     memcpy(&eps, dst->op_params, sizeof(float));
 
-    rms_norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream);
+    const int64_t ne00 = src0->ne[0];
+    if (ggml_is_contiguous(src0)) {
+        const int64_t nrows = ggml_nrows(src0);
+        rms_norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream);
+    } else {
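+        // non-contiguous src0: nb[] holds byte strides, convert them to element strides
+        // (the innermost dimension must still be contiguous)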
+        auto ts0 = ggml_type_size(src0->type);
+        GGML_ASSERT(src0->nb[0] == ts0);
+        auto s01 = src0->nb[1] / ts0;
+        auto s02 = src0->nb[2] / ts0;
+        auto s03 = src0->nb[3] / ts0;
+        rms_norm_f32_nc_cuda(src0_d, dst_d, ne00, src0->ne[1], src0->ne[2], src0->ne[3], s01, s02, s03, eps, stream);
+    }
 }
 
 void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -281,19 +403,26 @@ void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor *
     float * dst_d = (float *)dst->data;
     cudaStream_t stream = ctx.stream();
 
-    GGML_ASSERT(ggml_is_contiguous(src0));
-
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
     GGML_ASSERT(src0->ne[0] == src1->ne[0]);
     GGML_ASSERT(ggml_nrows(src1) == 1);
 
-    const int64_t ne00 = src0->ne[0];
-    const int64_t nrows = ggml_nrows(src0);
-
     float eps;
     memcpy(&eps, dst->op_params, sizeof(float));
 
-    fused_rms_norm_f32_cuda(src0_d, src1_d, dst_d, ne00, nrows, eps, stream);
+    const int64_t ne00 = src0->ne[0];
+
+    if (ggml_is_contiguous(src0)) {
+        const int64_t nrows = ggml_nrows(src0);
+        fused_rms_norm_f32_cuda(src0_d, src1_d, dst_d, ne00, nrows, eps, stream);
+    } else {
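+        // as above: convert byte strides to element strides for the non-contiguous case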
+        auto ts0 = ggml_type_size(src0->type);
+        GGML_ASSERT(src0->nb[0] == ts0);
+        auto s01 = src0->nb[1] / ts0;
+        auto s02 = src0->nb[2] / ts0;
+        auto s03 = src0->nb[3] / ts0;
+        fused_rms_norm_f32_nc_cuda(src0_d, src1_d, dst_d, ne00, src0->ne[1], src0->ne[2], src0->ne[3], s01, s02, s03, eps, stream);
+    }
 }