Skip to content

Commit 53ea164

Browse files
authored
Merge branch 'ikawrakow:main' into main
2 parents 1b417f5 + f5571e2 commit 53ea164

File tree

1 file changed

+36 -7 lines changed

1 file changed

+36 -7 lines changed

ggml/src/ggml-cuda/norm.cu

Lines changed: 36 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -119,7 +119,7 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncol
119119
s_sum[warp_id] = tmp;
120120
}
121121
__syncthreads();
122-
tmp = s_sum[lane_id];
122+
tmp = lane_id < block_size/WARP_SIZE ? s_sum[lane_id] : 0.0f;
123123
tmp = warp_reduce_sum(tmp);
124124
}
125125

@@ -198,7 +198,7 @@ static __global__ void fused_rms_norm_f32(const float * x, const float * y, floa
198198
s_sum[warp_id] = tmp;
199199
}
200200
__syncthreads();
201-
tmp = s_sum[lane_id];
201+
tmp = lane_id < block_size/WARP_SIZE ? s_sum[lane_id] : 0.0f;
202202
tmp = warp_reduce_sum(tmp);
203203
}
204204

@@ -219,6 +219,7 @@ static __global__ void fused_rms_norm_f32_nc(
219219

220220
const int row = blockIdx.x;
221221
const int channel = blockIdx.y;
222+
//const int channel = blockIdx.y * blockDim.y + threadIdx.y;
222223
const int sample = blockIdx.z;
223224
const int tid = threadIdx.x;
224225

@@ -244,6 +245,11 @@ static __global__ void fused_rms_norm_f32_nc(
244245
}
245246
__syncthreads();
246247
tmp = s_sum[lane_id];
248+
//if constexpr (block_size == 1024) {
249+
// tmp = s_sum[lane_id];
250+
//} else {
251+
// tmp = lane_id < block_size/WARP_SIZE ? s_sum[lane_id] : 0.0f;
252+
//}
247253
tmp = warp_reduce_sum(tmp);
248254
}
249255

@@ -278,9 +284,10 @@ static void group_norm_f32_cuda(const float * x, float * dst, const int num_grou
278284

279285
static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
280286
GGML_ASSERT(ncols % WARP_SIZE == 0);
287+
constexpr int kBlockSize = 256;
281288
if (ncols < 1024) {
282-
const dim3 block_dims(WARP_SIZE, 1, 1);
283-
rms_norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
289+
const dim3 block_dims(kBlockSize, 1, 1);
290+
rms_norm_f32<kBlockSize><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
284291
} else {
285292
const dim3 block_dims(1024, 1, 1);
286293
rms_norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
@@ -302,10 +309,22 @@ static void rms_norm_f32_nc_cuda(
302309

303310
static void fused_rms_norm_f32_cuda(const float * x, const float * y, float * dst,
304311
const int ncols, const int nrows, const float eps, cudaStream_t stream) {
312+
constexpr int kBlockSize = 256;
305313
GGML_ASSERT(ncols % WARP_SIZE == 0);
306-
if (ncols < 1024) {
307-
const dim3 block_dims(WARP_SIZE, 1, 1);
308-
fused_rms_norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, y, dst, ncols, eps);
314+
if (ncols < kBlockSize) {
315+
switch (ncols) {
316+
case 32: fused_rms_norm_f32< 32><<<nrows, 32, 0, stream>>>(x, y, dst, ncols, eps); break;
317+
case 64: fused_rms_norm_f32< 64><<<nrows, 64, 0, stream>>>(x, y, dst, ncols, eps); break;
318+
case 96: fused_rms_norm_f32< 96><<<nrows, 96, 0, stream>>>(x, y, dst, ncols, eps); break;
319+
case 128: fused_rms_norm_f32<128><<<nrows, 128, 0, stream>>>(x, y, dst, ncols, eps); break;
320+
case 160: fused_rms_norm_f32<160><<<nrows, 160, 0, stream>>>(x, y, dst, ncols, eps); break;
321+
case 192: fused_rms_norm_f32<192><<<nrows, 192, 0, stream>>>(x, y, dst, ncols, eps); break;
322+
default : fused_rms_norm_f32<224><<<nrows, 224, 0, stream>>>(x, y, dst, ncols, eps); break;
323+
}
324+
}
325+
else if (ncols < 1024) {
326+
const dim3 block_dims(kBlockSize, 1, 1);
327+
fused_rms_norm_f32<kBlockSize><<<nrows, block_dims, 0, stream>>>(x, y, dst, ncols, eps);
309328
} else {
310329
const dim3 block_dims(1024, 1, 1);
311330
fused_rms_norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, y, dst, ncols, eps);
@@ -319,6 +338,16 @@ static void fused_rms_norm_f32_nc_cuda(
319338
if (ncols < 1024) {
320339
const dim3 block_dims(WARP_SIZE, 1, 1);
321340
fused_rms_norm_f32_nc<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, y, dst, ncols, stride_row, stride_channel, stride_sample, eps);
341+
//constexpr int kBlockSize = 256;
342+
343+
//if (nchannels%4 == 0) {
344+
// const dim3 blocks_num(nrows, nchannels/4, nsamples);
345+
// const dim3 block_dims(kBlockSize, 4, 1);
346+
// fused_rms_norm_f32_nc<kBlockSize><<<blocks_num, block_dims, 0, stream>>>(x, y, dst, ncols, stride_row, stride_channel, stride_sample, eps);
347+
//} else {
348+
// const dim3 block_dims(kBlockSize, 1, 1);
349+
// fused_rms_norm_f32_nc<kBlockSize><<<blocks_num, block_dims, 0, stream>>>(x, y, dst, ncols, stride_row, stride_channel, stride_sample, eps);
350+
//}
322351
} else {
323352
const dim3 block_dims(1024, 1, 1);
324353
fused_rms_norm_f32_nc<1024><<<blocks_num, block_dims, 0, stream>>>(x, y, dst, ncols, stride_row, stride_channel, stride_sample, eps);

0 commit comments

Comments (0)