Skip to content

Commit cd3dcdb

Browse files
author
alexju
committed
metal : refactor sum_rows parameters into a struct
1 parent f8d1bf2 commit cd3dcdb

File tree

3 files changed

+61
-53
lines changed

3 files changed

+61
-53
lines changed

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,4 +303,31 @@ typedef struct {
303303
int32_t KHW; // KH * KW, pre-computed on CPU to save GPU resources
304304
} ggml_metal_kargs_im2col;
305305

306+
typedef struct {
307+
int64_t ne00;
308+
int64_t ne01;
309+
int64_t ne02;
310+
int64_t ne03;
311+
uint64_t nb00;
312+
uint64_t nb01;
313+
uint64_t nb02;
314+
uint64_t nb03;
315+
int64_t ne10;
316+
int64_t ne11;
317+
int64_t ne12;
318+
int64_t ne13;
319+
uint64_t nb10;
320+
uint64_t nb11;
321+
uint64_t nb12;
322+
uint64_t nb13;
323+
int64_t ne0;
324+
int64_t ne1;
325+
int64_t ne2;
326+
int64_t ne3;
327+
uint64_t nb0;
328+
uint64_t nb1;
329+
uint64_t nb2;
330+
uint64_t nb3;
331+
} ggml_metal_kargs_sum_rows;
332+
306333
#endif // GGML_METAL_IMPL

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 29 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1944,34 +1944,38 @@ static void ggml_metal_encode_node(
19441944

19451945
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
19461946

1947-
// TODO: add ggml_metal_kargs struct
1947+
1948+
ggml_metal_kargs_sum_rows args = {
1949+
/*.ne00 =*/ ne00,
1950+
/*.ne01 =*/ ne01,
1951+
/*.ne02 =*/ ne02,
1952+
/*.ne03 =*/ ne03,
1953+
/*.nb00 =*/ nb00,
1954+
/*.nb01 =*/ nb01,
1955+
/*.nb02 =*/ nb02,
1956+
/*.nb03 =*/ nb03,
1957+
/*.ne10 =*/ ne10,
1958+
/*.ne11 =*/ ne11,
1959+
/*.ne12 =*/ ne12,
1960+
/*.ne13 =*/ ne13,
1961+
/*.nb10 =*/ nb10,
1962+
/*.nb11 =*/ nb11,
1963+
/*.nb12 =*/ nb12,
1964+
/*.nb13 =*/ nb13,
1965+
/*.ne0 =*/ ne0,
1966+
/*.ne1 =*/ ne1,
1967+
/*.ne2 =*/ ne2,
1968+
/*.ne3 =*/ ne3,
1969+
/*.nb0 =*/ nb0,
1970+
/*.nb1 =*/ nb1,
1971+
/*.nb2 =*/ nb2,
1972+
/*.nb3 =*/ nb3,
1973+
};
1974+
19481975
[encoder setComputePipelineState:pipeline];
19491976
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
19501977
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1951-
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
1952-
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
1953-
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
1954-
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
1955-
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
1956-
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
1957-
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
1958-
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
1959-
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
1960-
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
1961-
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
1962-
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
1963-
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
1964-
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
1965-
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
1966-
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
1967-
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18];
1968-
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:19];
1969-
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:20];
1970-
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:21];
1971-
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:22];
1972-
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:23];
1973-
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:24];
1974-
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:25];
1978+
[encoder setBytes:&args length:sizeof(args) atIndex:2];
19751979

19761980
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
19771981
} break;

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 5 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -948,45 +948,22 @@ kernel void kernel_cos(
948948
kernel void kernel_sum_rows(
949949
device const float * src0,
950950
device float * dst,
951-
constant int64_t & ne00,
952-
constant int64_t & ne01,
953-
constant int64_t & ne02,
954-
constant int64_t & ne03,
955-
constant uint64_t & nb00,
956-
constant uint64_t & nb01,
957-
constant uint64_t & nb02,
958-
constant uint64_t & nb03,
959-
constant int64_t & ne10,
960-
constant int64_t & ne11,
961-
constant int64_t & ne12,
962-
constant int64_t & ne13,
963-
constant uint64_t & nb10,
964-
constant uint64_t & nb11,
965-
constant uint64_t & nb12,
966-
constant uint64_t & nb13,
967-
constant int64_t & ne0,
968-
constant int64_t & ne1,
969-
constant int64_t & ne2,
970-
constant int64_t & ne3,
971-
constant uint64_t & nb0,
972-
constant uint64_t & nb1,
973-
constant uint64_t & nb2,
974-
constant uint64_t & nb3,
951+
constant ggml_metal_kargs_sum_rows & args,
975952
uint3 tpig[[thread_position_in_grid]]) {
976953
int64_t i3 = tpig.z;
977954
int64_t i2 = tpig.y;
978955
int64_t i1 = tpig.x;
979956

980-
if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
957+
if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
981958
return;
982959
}
983960

984-
device const float * src_row = (device const float *) ((device const char *) src0 + i1*nb01 + i2*nb02 + i3*nb03);
985-
device float * dst_row = (device float *) ((device char *) dst + i1*nb1 + i2*nb2 + i3*nb3);
961+
device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
962+
device float * dst_row = (device float *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);
986963

987964
float row_sum = 0;
988965

989-
for (int64_t i0 = 0; i0 < ne00; i0++) {
966+
for (int64_t i0 = 0; i0 < args.ne00; i0++) {
990967
row_sum += src_row[i0];
991968
}
992969

0 commit comments

Comments
 (0)