Skip to content

Commit 08bad56

Browse files
author
alexju
committed
metal : refactor group_norm parameters into a struct
1 parent 9c4dbad commit 08bad56

File tree

3 files changed

+27
-20
lines changed

3 files changed

+27
-20
lines changed

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,17 @@ typedef struct {
285285
float eps;
286286
} ggml_metal_kargs_rms_norm;
287287

288+
typedef struct { // argument block for kernel_group_norm; the host uploads it with a single setBytes call at buffer index 2
289+
int64_t ne00; // extent of src0 along dim 0; the kernel multiplies ne00*ne01*ne02 to get the total element count
290+
int64_t ne01; // extent of src0 along dim 1; together with ne00 it scales the per-group element count
291+
int64_t ne02; // extent of src0 along dim 2; this is the dimension that gets divided into n_groups (rounded up)
292+
uint64_t nb00; // NOTE(review): presumably the byte stride of src0 along dim 0 - not read in the visible kernel code, confirm
293+
uint64_t nb01; // NOTE(review): presumably the byte stride of src0 along dim 1 - confirm
294+
uint64_t nb02; // NOTE(review): presumably the byte stride of src0 along dim 2 - confirm
295+
int32_t n_groups; // number of normalization groups; the host dispatches one threadgroup per group
296+
float eps; // added to the variance before the sqrt when the kernel computes the scale factor
297+
} ggml_metal_kargs_group_norm;
298+
288299
typedef struct {
289300
uint64_t ofs0;
290301
uint64_t ofs1;

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3125,18 +3125,21 @@ static void ggml_metal_encode_node(
31253125

31263126
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline;
31273127

3128-
// TODO: add ggml_metal_kargs struct
3128+
ggml_metal_kargs_group_norm args = {
3129+
/*.ne00 =*/ ne00,
3130+
/*.ne01 =*/ ne01,
3131+
/*.ne02 =*/ ne02,
3132+
/*.nb00 =*/ nb00,
3133+
/*.nb01 =*/ nb01,
3134+
/*.nb02 =*/ nb02,
3135+
/*.n_groups =*/ n_groups,
3136+
/*.eps =*/ eps,
3137+
};
3138+
31293139
[encoder setComputePipelineState:pipeline];
31303140
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
31313141
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
3132-
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
3133-
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
3134-
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
3135-
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:5];
3136-
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:6];
3137-
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:7];
3138-
[encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8];
3139-
[encoder setBytes:&eps length:sizeof( float) atIndex:9];
3142+
[encoder setBytes:&args length:sizeof(args) atIndex:2];
31403143
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
31413144

31423145
[encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1467,22 +1467,15 @@ kernel void kernel_rms_norm(
14671467
kernel void kernel_group_norm(
14681468
device const float * src0,
14691469
device float * dst,
1470-
constant int64_t & ne00,
1471-
constant int64_t & ne01,
1472-
constant int64_t & ne02,
1473-
constant uint64_t & nb00,
1474-
constant uint64_t & nb01,
1475-
constant uint64_t & nb02,
1476-
constant int32_t & n_groups,
1477-
constant float & eps,
1470+
constant ggml_metal_kargs_group_norm & args,
14781471
threadgroup float * buf [[threadgroup(0)]],
14791472
uint tgpig[[threadgroup_position_in_grid]],
14801473
uint tpitg[[thread_position_in_threadgroup]],
14811474
uint sgitg[[simdgroup_index_in_threadgroup]],
14821475
uint tiisg[[thread_index_in_simdgroup]],
14831476
uint ntg[[threads_per_threadgroup]]) {
1484-
const int64_t ne = ne00*ne01*ne02;
1485-
const int64_t gs = ne00*ne01*((ne02 + n_groups - 1) / n_groups);
1477+
const int64_t ne = args.ne00*args.ne01*args.ne02;
1478+
const int64_t gs = args.ne00*args.ne01*((args.ne02 + args.n_groups - 1) / args.n_groups);
14861479

14871480
int start = tgpig * gs;
14881481
int end = start + gs;
@@ -1546,7 +1539,7 @@ kernel void kernel_group_norm(
15461539
}
15471540

15481541
const float variance = tmp / gs;
1549-
const float scale = 1.0f/sqrt(variance + eps);
1542+
const float scale = 1.0f/sqrt(variance + args.eps);
15501543
for (int j = start; j < end; j += ntg) {
15511544
dst[j] *= scale;
15521545
}

0 commit comments

Comments
 (0)