
Commit caf9759

ikawrakow and Iwan Kawrakow authored
Fuse add + fused_rms_norm (CUDA) (ikawrakow#852)

* Combine all calls to llm_build_norm into a single line, so that the kind of arguments being passed can be checked easily with a simple grep.

* Combine add + fused_rms_norm. For many models this happens at each layer: the result of the layer is added to the layer input, which then becomes the input to the next layer, where it is typically normalized via fused_rms_norm.

---------

Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent 9223146 commit caf9759
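To make the fused operation concrete, here is a minimal host-side reference sketch (not part of the commit) of what the fused path computes for one row: the residual add followed by an RMS norm scaled by a per-channel weight. The helper name and the double-precision accumulator are illustrative choices only.

#include <cmath>

// Reference for one row of ncols floats (unfused view of the two ops):
//   1) GGML_OP_ADD:            h[i] = a[i] + b[i]
//   2) GGML_OP_FUSED_RMS_NORM: y[i] = w[i] * h[i] / sqrt(mean(h*h) + eps)
// The fused CUDA kernel does both in a single pass; it still writes h,
// because the add result is reused as the residual input of the next layer.
static void add_rms_norm_row_ref(const float * a, const float * b, const float * w,
                                 float * h, float * y, int ncols, float eps) {
    double sumsq = 0.0;
    for (int i = 0; i < ncols; ++i) {
        h[i] = a[i] + b[i];
        sumsq += (double)h[i] * (double)h[i];
    }
    const float scale = 1.0f / std::sqrt((float)(sumsq / ncols) + eps);
    for (int i = 0; i < ncols; ++i) {
        y[i] = scale * w[i] * h[i];
    }
}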

File tree

4 files changed: +258 −531 lines changed


ggml/src/ggml-cuda.cu

Lines changed: 11 additions & 1 deletion
@@ -3129,7 +3129,17 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
             ggml_cuda_dup(ctx, dst);
             break;
         case GGML_OP_ADD:
-            ggml_cuda_op_add(ctx, dst);
+            if (i + 1 < cgraph->n_nodes &&
+                cgraph->nodes[i+1]->op == GGML_OP_FUSED_RMS_NORM &&
+                ggml_is_contiguous(dst->src[0]) &&
+                ggml_is_contiguous(dst->src[1]) &&
+                ggml_are_same_shape(dst->src[0], dst->src[1])) {
+                ggml_cuda_op_fused_add_rms_norm(ctx, dst, cgraph->nodes[i+1]);
+                ++i;
+            } else {
+                ggml_cuda_op_add(ctx, dst);
+            }
+            //ggml_cuda_op_add(ctx, dst);
             break;
         case GGML_OP_ADD_ID:
             ggml_cuda_op_add_id(ctx, dst);
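In words, the dispatch above fuses at graph-execution time: when the node right after a GGML_OP_ADD is a GGML_OP_FUSED_RMS_NORM over contiguous, same-shaped operands, a single fused kernel handles both and the norm node is skipped via ++i. A standalone sketch of that eligibility test, restated from the diff (the helper name is illustrative, not part of the patch):

// Sketch only: restates the fusion condition from the dispatch above.
// 'i' indexes the GGML_OP_ADD node inside the graph-execution loop.
static bool can_fuse_add_rms_norm(const ggml_cgraph * cgraph, int i) {
    const ggml_tensor * add = cgraph->nodes[i];
    return i + 1 < cgraph->n_nodes
        && cgraph->nodes[i + 1]->op == GGML_OP_FUSED_RMS_NORM
        && ggml_is_contiguous(add->src[0])
        && ggml_is_contiguous(add->src[1])
        && ggml_are_same_shape(add->src[0], add->src[1]);
}

Note that this condition does not itself verify that the norm consumes the add's output; ggml_cuda_op_fused_add_rms_norm asserts that below (add->data == src0->data) in norm.cu.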

ggml/src/ggml-cuda/norm.cu

Lines changed: 83 additions & 0 deletions
@@ -455,3 +455,86 @@ void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor *
         fused_rms_norm_f32_nc_cuda(src0_d, src1_d, dst_d, ne00, src0->ne[1], src0->ne[2], src0->ne[3], s01, s02, s03, eps, stream);
     }
 }
+
+template <int block_size>
+static __global__ void fused_add_rms_norm_f32(const float * a, const float * b, const float * c,
+        float * dst_add, float * dst, const int ncols, const float eps) {
+    const int row = blockIdx.x*blockDim.y + threadIdx.y;
+    const int tid = threadIdx.x;
+
+    float tmp = 0.0f; // partial sum for thread in warp
+
+    for (int col = tid; col < ncols; col += block_size) {
+        const float xi = a[row*ncols + col] + b[row*ncols + col];
+        tmp += xi * xi;
+        dst_add[row*ncols + col] = xi;
+    }
+
+    // sum up partial sums
+    tmp = warp_reduce_sum(tmp);
+    if (block_size > WARP_SIZE) {
+        __shared__ float s_sum[32];
+        int warp_id = threadIdx.x / WARP_SIZE;
+        int lane_id = threadIdx.x % WARP_SIZE;
+        if (lane_id == 0) {
+            s_sum[warp_id] = tmp;
+        }
+        __syncthreads();
+        tmp = lane_id < block_size/WARP_SIZE ? s_sum[lane_id] : 0.0f;
+        tmp = warp_reduce_sum(tmp);
+    }
+
+    const float mean = tmp / ncols;
+    const float scale = rsqrtf(mean + eps);
+
+    for (int col = tid; col < ncols; col += block_size) {
+        dst[row*ncols + col] = scale * c[col] * dst_add[row*ncols + col];
+    }
+}
+
+
+static void fused_add_rms_norm_f32_cuda(const float * a, const float * b, const float * c, float * dst_add, float * dst,
+        const int ncols, const int nrows, const float eps, cudaStream_t stream) {
+    GGML_ASSERT(ncols % WARP_SIZE == 0);
+    if (ncols < 1024) {
+        const dim3 block_dims(256, 1, 1);
+        fused_add_rms_norm_f32<256><<<nrows, block_dims, 0, stream>>>(a, b, c, dst_add, dst, ncols, eps);
+    } else {
+        const dim3 block_dims(1024, 1, 1);
+        fused_add_rms_norm_f32<1024><<<nrows, block_dims, 0, stream>>>(a, b, c, dst_add, dst, ncols, eps);
+    }
+}
+
+void ggml_cuda_op_fused_add_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * add, ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    //const float * src0_d = (const float *)src0->data;
+    const float * src1_d = (const float *)src1->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(add->data == src0->data);
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(add->src[0]));
+    GGML_ASSERT(ggml_is_contiguous(add->src[1]));
+    GGML_ASSERT(ggml_are_same_shape(add->src[0], add->src[1]));
+    GGML_ASSERT(ggml_are_same_shape(add->src[0], src0));
+    GGML_ASSERT(add->src[0]->type == GGML_TYPE_F32);
+    GGML_ASSERT(add->src[1]->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->ne[0] == src1->ne[0]);
+    GGML_ASSERT(ggml_nrows(src1) == 1);
+
+    float eps;
+    memcpy(&eps, dst->op_params, sizeof(float));
+
+    const int64_t ne00 = src0->ne[0];
+
+    const int64_t nrows = ggml_nrows(src0);
+    fused_add_rms_norm_f32_cuda((const float *)add->src[0]->data, (const float *)add->src[1]->data,
+        src1_d, (float *)add->data, dst_d, ne00, nrows, eps, stream);
+}
+

ggml/src/ggml-cuda/norm.cuh

Lines changed: 2 additions & 0 deletions
@@ -7,3 +7,5 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
 void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

 void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_fused_add_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * add, ggml_tensor * dst);
