Skip to content

Commit b438ff7

Browse files
committed
metal : GGML_OP_RMS_NORM
1 parent 2b86f84 commit b438ff7

File tree

3 files changed

+51
-41
lines changed

3 files changed

+51
-41
lines changed

ggml/src/ggml-common.h

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -642,6 +642,13 @@ typedef struct {
642642
int32_t ne1;
643643
uint64_t nb1;
644644
} ggml_metal_kargs_mul_mv_id;
645+
646+
// Argument block for kernel_rms_norm. The host fills one instance and binds it
// with setBytes at index 0; the kernel receives it as `constant ... & args`.
typedef struct {
647+
int32_t ne00; // number of float elements per row; divisor for the mean of squares
648+
int32_t ne00_4; // host sets this to ne00/4: count of float4 elements the kernel loops over per row
649+
uint64_t nb01; // byte offset applied to src0 per threadgroup index (row stride in bytes)
650+
float eps; // added to the mean before taking the square root
651+
} ggml_metal_kargs_rms_norm;
645652
#endif
646653

647654
#endif // GGML_COMMON_DECL

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 15 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -2622,20 +2622,28 @@ static void ggml_metal_encode_node(
26222622
float eps;
26232623
memcpy(&eps, dst->op_params, sizeof(float));
26242624

2625+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline;
2626+
26252627
int nth = 32; // SIMD width
26262628

2627-
while (nth < ne00/4 && nth < 1024) {
2629+
while (nth < ne00/4 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
26282630
nth *= 2;
26292631
}
26302632

2631-
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline;
2633+
nth = MIN(nth, ne00/4);
2634+
2635+
ggml_metal_kargs_rms_norm args = {
2636+
/*.ne00 =*/ ne00,
2637+
/*.ne00_4 =*/ ne00/4,
2638+
/*.nb01 =*/ nb01,
2639+
/*.eps =*/ eps,
2640+
};
26322641

26332642
[encoder setComputePipelineState:pipeline];
2634-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2635-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2636-
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
2637-
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
2638-
[encoder setBytes:&eps length:sizeof( float) atIndex:4];
2643+
[encoder setBytes:&args length:sizeof(args) atIndex:0];
2644+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
2645+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
2646+
26392647
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
26402648

26412649
const int64_t nrows = ggml_nrows(src0);

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 29 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -1293,50 +1293,45 @@ kernel void kernel_norm(
12931293
}
12941294

12951295
// RMS normalization kernel. Each threadgroup handles the row selected by tgpig:
// it accumulates the sum of squares over the row's float4 elements, reduces that
// sum within each simdgroup and then across simdgroups through threadgroup memory,
// and writes x * 1/sqrt(mean + eps) to the matching row of dst.
// NOTE(review): this span is a rendered diff, so both the previous and the new
// signature/body appear interleaved with diff line-number markers.
kernel void kernel_rms_norm(
1296-
device const void * src0,
1297-
device float * dst,
1298-
constant int64_t & ne00,
1299-
constant uint64_t & nb01,
1300-
constant float & eps,
1301-
threadgroup float * buf [[threadgroup(0)]],
1302-
uint tgpig[[threadgroup_position_in_grid]],
1303-
uint tpitg[[thread_position_in_threadgroup]],
1304-
uint sgitg[[simdgroup_index_in_threadgroup]],
1305-
uint tiisg[[thread_index_in_simdgroup]],
1306-
uint ntg[[threads_per_threadgroup]]) {
1307-
device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
1296+
constant ggml_metal_kargs_rms_norm & args,
1297+
device const char * src0,
1298+
device char * dst,
1299+
threadgroup float * shmem_f32 [[threadgroup(0)]],
1300+
uint tgpig[[threadgroup_position_in_grid]],
1301+
ushort tpitg[[thread_position_in_threadgroup]],
1302+
ushort sgitg[[simdgroup_index_in_threadgroup]],
1303+
ushort tiisg[[thread_index_in_simdgroup]],
1304+
ushort ntg[[threads_per_threadgroup]]) {
1305+
// the first simdgroup zeroes the scratch slots, so any slot that no simdgroup
// writes later contributes 0 to the final reduction
if (sgitg == 0) {
1306+
shmem_f32[tiisg] = 0.0f;
1307+
}
13081308

1309-
float4 sumf = 0;
1310-
float all_sum = 0;
1309+
// src0 is a byte pointer, so the per-threadgroup row offset nb01 is in bytes
device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01);
1310+
1311+
float sumf = 0.0f;
13111312

13121313
// parallel sum
1313-
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
1314-
sumf += x[i00] * x[i00];
1314+
for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
1315+
// dot(v, v) adds the squares of the four components in one step
sumf += dot(x[i00], x[i00]);
13151316
}
1316-
all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3];
1317-
all_sum = simd_sum(all_sum);
1318-
if (ntg > N_SIMDWIDTH) {
1319-
if (sgitg == 0) {
1320-
buf[tiisg] = 0.0f;
1321-
}
1317+
// reduce the per-thread partial sums within the simdgroup
sumf = simd_sum(sumf);
13221318

1323-
threadgroup_barrier(mem_flags::mem_threadgroup);
1319+
// barrier between zeroing the scratch slots and storing the per-simdgroup sums
threadgroup_barrier(mem_flags::mem_threadgroup);
13241320

1325-
if (tiisg == 0) {
1326-
buf[sgitg] = all_sum;
1327-
}
1321+
// one thread per simdgroup stores that simdgroup's partial sum in slot sgitg
if (tiisg == 0) {
1322+
shmem_f32[sgitg] = sumf;
1323+
}
13281324

1329-
threadgroup_barrier(mem_flags::mem_threadgroup);
1325+
threadgroup_barrier(mem_flags::mem_threadgroup);
13301326

1331-
all_sum = buf[tiisg];
1332-
all_sum = simd_sum(all_sum);
1333-
}
1327+
// every thread loads slot tiisg and reduces again to combine the stored partial sums
sumf = shmem_f32[tiisg];
1328+
sumf = simd_sum(sumf);
13341329

1335-
const float mean = all_sum/ne00;
1336-
const float scale = 1.0f/sqrt(mean + eps);
1330+
const float mean = sumf/args.ne00;
1331+
const float scale = 1.0f/sqrt(mean + args.eps);
13371332

1338-
device float4 * y = (device float4 *) (dst + tgpig*ne00);
1339-
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
1333+
// dst is cast to float4 before the offset is added, so the row offset is in float4 units
device float4 * y = (device float4 *) dst + tgpig*args.ne00_4;
1334+
for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
13401335
y[i00] = x[i00] * scale;
13411336
}
13421337
}

0 commit comments

Comments
 (0)