Skip to content

Commit 43dce0f

Browse files
ikawrakowIwan Kawrakow
authored andcommitted
Adding fused rms_norm (#42)
* Fused rms_norm: works on the CPU * Fused rms_norm WIP * Fused rms_norm WIP * Fused rms_norm WIP * Fused rms_norm WIP * Fused rms_norm WIP --------- Co-Authored-By: Iwan Kawrakow <[email protected]>
1 parent 3e24d02 commit 43dce0f

File tree

8 files changed

+327
-3
lines changed

8 files changed

+327
-3
lines changed

ggml/include/ggml.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -475,6 +475,7 @@ extern "C" {
475475
GGML_OP_RMS_NORM,
476476
GGML_OP_RMS_NORM_BACK,
477477
GGML_OP_GROUP_NORM,
478+
GGML_OP_FUSED_RMS_NORM,
478479

479480
GGML_OP_MUL_MAT,
480481
GGML_OP_MUL_MAT_ID,
@@ -1176,6 +1177,18 @@ extern "C" {
11761177
struct ggml_tensor * a,
11771178
float eps);
11781179

1180+
GGML_API struct ggml_tensor * ggml_fused_rms_norm(
1181+
struct ggml_context * ctx,
1182+
struct ggml_tensor * a,
1183+
struct ggml_tensor * b,
1184+
float eps);
1185+
1186+
GGML_API struct ggml_tensor * ggml_fused_rms_norm_inplace(
1187+
struct ggml_context * ctx,
1188+
struct ggml_tensor * a,
1189+
struct ggml_tensor * b,
1190+
float eps);
1191+
11791192
// group normalize along ne0*ne1*n_groups
11801193
// used in stable-diffusion
11811194
GGML_API struct ggml_tensor * ggml_group_norm(

ggml/src/ggml-cuda.cu

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2260,6 +2260,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
22602260
case GGML_OP_RMS_NORM:
22612261
ggml_cuda_op_rms_norm(ctx, dst);
22622262
break;
2263+
case GGML_OP_FUSED_RMS_NORM:
2264+
ggml_cuda_op_fused_rms_norm(ctx, dst);
2265+
break;
22632266
case GGML_OP_MUL_MAT:
22642267
if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) {
22652268
GGML_LOG_ERROR("%s: cannot compute %s: src0->ne[3] = %" PRId64 ", src1->ne[3] = %" PRId64 " - fallback to CPU\n", __func__, dst->name, dst->src[0]->ne[3], dst->src[1]->ne[3]);
@@ -3139,6 +3142,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
31393142
case GGML_OP_MUL:
31403143
case GGML_OP_DIV:
31413144
case GGML_OP_RMS_NORM:
3145+
case GGML_OP_FUSED_RMS_NORM:
31423146
case GGML_OP_SCALE:
31433147
case GGML_OP_SQR:
31443148
case GGML_OP_SQRT:

ggml/src/ggml-cuda/norm.cu

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,40 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncol
131131
}
132132
}
133133

134+
template <int block_size>
135+
static __global__ void fused_rms_norm_f32(const float * x, const float * y, float * dst, const int ncols, const float eps) {
136+
const int row = blockIdx.x*blockDim.y + threadIdx.y;
137+
const int tid = threadIdx.x;
138+
139+
float tmp = 0.0f; // partial sum for thread in warp
140+
141+
for (int col = tid; col < ncols; col += block_size) {
142+
const float xi = x[row*ncols + col];
143+
tmp += xi * xi;
144+
}
145+
146+
// sum up partial sums
147+
tmp = warp_reduce_sum(tmp);
148+
if (block_size > WARP_SIZE) {
149+
__shared__ float s_sum[32];
150+
int warp_id = threadIdx.x / WARP_SIZE;
151+
int lane_id = threadIdx.x % WARP_SIZE;
152+
if (lane_id == 0) {
153+
s_sum[warp_id] = tmp;
154+
}
155+
__syncthreads();
156+
tmp = s_sum[lane_id];
157+
tmp = warp_reduce_sum(tmp);
158+
}
159+
160+
const float mean = tmp / ncols;
161+
const float scale = rsqrtf(mean + eps);
162+
163+
for (int col = tid; col < ncols; col += block_size) {
164+
dst[row*ncols + col] = scale * y[col] * x[row*ncols + col];
165+
}
166+
}
167+
134168
static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
135169
GGML_ASSERT(ncols % WARP_SIZE == 0);
136170
if (ncols < 1024) {
@@ -163,6 +197,18 @@ static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, con
163197
}
164198
}
165199

200+
static void fused_rms_norm_f32_cuda(const float * x, const float * y, float * dst,
201+
const int ncols, const int nrows, const float eps, cudaStream_t stream) {
202+
GGML_ASSERT(ncols % WARP_SIZE == 0);
203+
if (ncols < 1024) {
204+
const dim3 block_dims(WARP_SIZE, 1, 1);
205+
fused_rms_norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, y, dst, ncols, eps);
206+
} else {
207+
const dim3 block_dims(1024, 1, 1);
208+
fused_rms_norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, y, dst, ncols, eps);
209+
}
210+
}
211+
166212
void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
167213
const ggml_tensor * src0 = dst->src[0];
168214
const float * src0_d = (const float *)src0->data;
@@ -222,3 +268,32 @@ void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
222268

223269
rms_norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream);
224270
}
271+
272+
void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
273+
if (!dst->src[1]) {
274+
ggml_cuda_op_rms_norm(ctx, dst);
275+
return;
276+
}
277+
const ggml_tensor * src0 = dst->src[0];
278+
const ggml_tensor * src1 = dst->src[1];
279+
const float * src0_d = (const float *)src0->data;
280+
const float * src1_d = (const float *)src1->data;
281+
float * dst_d = (float *)dst->data;
282+
cudaStream_t stream = ctx.stream();
283+
284+
GGML_ASSERT(ggml_is_contiguous(src0));
285+
286+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
287+
GGML_ASSERT(src1->type == GGML_TYPE_F32);
288+
GGML_ASSERT( dst->type == GGML_TYPE_F32);
289+
GGML_ASSERT(src0->ne[0] == src1->ne[0]);
290+
GGML_ASSERT(ggml_nrows(src1) == 1);
291+
292+
const int64_t ne00 = src0->ne[0];
293+
const int64_t nrows = ggml_nrows(src0);
294+
295+
float eps;
296+
memcpy(&eps, dst->op_params, sizeof(float));
297+
298+
fused_rms_norm_f32_cuda(src0_d, src1_d, dst_d, ne00, nrows, eps, stream);
299+
}

ggml/src/ggml-cuda/norm.cuh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,5 @@ void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
55
void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
66

77
void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
8+
9+
void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

ggml/src/ggml-metal.m

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
142142
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
143143
GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
144144
GGML_METAL_KERNEL_TYPE_RMS_NORM,
145+
GGML_METAL_KERNEL_TYPE_FUSED_RMS_NORM,
145146
GGML_METAL_KERNEL_TYPE_GROUP_NORM,
146147
GGML_METAL_KERNEL_TYPE_NORM,
147148
GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
@@ -596,6 +597,7 @@ @implementation GGMLMetalClass
596597
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
597598
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
598599
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, support_simdgroup_reduction);
600+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FUSED_RMS_NORM, fused_rms_norm, support_simdgroup_reduction);
599601
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, support_simdgroup_reduction);
600602
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
601603
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true);
@@ -856,6 +858,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
856858
case GGML_OP_SUM_ROWS:
857859
case GGML_OP_SOFT_MAX:
858860
case GGML_OP_RMS_NORM:
861+
case GGML_OP_FUSED_RMS_NORM:
859862
case GGML_OP_GROUP_NORM:
860863
return support_simdgroup_reduction;
861864
case GGML_OP_NORM:
@@ -2440,6 +2443,38 @@ static void ggml_metal_encode_node(
24402443

24412444
const int64_t nrows = ggml_nrows(src0);
24422445

2446+
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2447+
} break;
2448+
case GGML_OP_FUSED_RMS_NORM:
2449+
{
2450+
GGML_ASSERT(ne00 % 4 == 0);
2451+
GGML_ASSERT(ggml_is_contiguous_1(src0));
2452+
GGML_ASSERT(src1->ne[0] == src0->ne[0]);
2453+
GGML_ASSERT(src1->type == GGML_TYPE_F32);
2454+
GGML_ASSERT(ggml_nrows(src1) == 1);
2455+
2456+
float eps;
2457+
memcpy(&eps, dst->op_params, sizeof(float));
2458+
2459+
int nth = 32; // SIMD width
2460+
2461+
while (nth < ne00/4 && nth < 1024) {
2462+
nth *= 2;
2463+
}
2464+
2465+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FUSED_RMS_NORM].pipeline;
2466+
2467+
[encoder setComputePipelineState:pipeline];
2468+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2469+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
2470+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
2471+
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
2472+
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
2473+
[encoder setBytes:&eps length:sizeof( float) atIndex:5];
2474+
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
2475+
2476+
const int64_t nrows = ggml_nrows(src0);
2477+
24432478
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
24442479
} break;
24452480
case GGML_OP_GROUP_NORM:

ggml/src/ggml-metal.metal

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -969,6 +969,57 @@ kernel void kernel_rms_norm(
969969
}
970970
}
971971

972+
kernel void kernel_fused_rms_norm(
973+
device const void * src0,
974+
device const void * src1,
975+
device float * dst,
976+
constant int64_t & ne00,
977+
constant uint64_t & nb01,
978+
constant float & eps,
979+
threadgroup float * buf [[threadgroup(0)]],
980+
uint tgpig[[threadgroup_position_in_grid]],
981+
uint tpitg[[thread_position_in_threadgroup]],
982+
uint sgitg[[simdgroup_index_in_threadgroup]],
983+
uint tiisg[[thread_index_in_simdgroup]],
984+
uint ntg[[threads_per_threadgroup]]) {
985+
device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
986+
987+
float4 sumf = 0;
988+
float all_sum = 0;
989+
990+
// parallel sum
991+
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
992+
sumf += x[i00] * x[i00];
993+
}
994+
all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3];
995+
all_sum = simd_sum(all_sum);
996+
if (ntg > N_SIMDWIDTH) {
997+
if (sgitg == 0) {
998+
buf[tiisg] = 0.0f;
999+
}
1000+
1001+
threadgroup_barrier(mem_flags::mem_threadgroup);
1002+
1003+
if (tiisg == 0) {
1004+
buf[sgitg] = all_sum;
1005+
}
1006+
1007+
threadgroup_barrier(mem_flags::mem_threadgroup);
1008+
1009+
all_sum = buf[tiisg];
1010+
all_sum = simd_sum(all_sum);
1011+
}
1012+
1013+
const float mean = all_sum/ne00;
1014+
const float scale = 1.0f/sqrt(mean + eps);
1015+
1016+
device float4 * y = (device float4 *) (dst + tgpig*ne00);
1017+
device float4 * z = (device float4 *)src1;
1018+
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
1019+
y[i00] = x[i00] * z[i00] * scale;
1020+
}
1021+
}
1022+
9721023
kernel void kernel_group_norm(
9731024
device const float * src0,
9741025
device float * dst,

0 commit comments

Comments
 (0)