Skip to content

Commit 9c4dbad

Browse files
author
alexju
committed
metal : refactor get_rows parameters into a struct
1 parent 678b2e7 commit 9c4dbad

File tree

3 files changed

+38
-45
lines changed

3 files changed

+38
-45
lines changed

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -391,4 +391,15 @@ typedef struct {
391391
uint64_t nb52;
392392
} ggml_metal_kargs_ssm_scan;
393393

394+
typedef struct {
395+
int64_t ne00;
396+
uint64_t nb01;
397+
uint64_t nb02;
398+
int64_t ne10;
399+
uint64_t nb10;
400+
uint64_t nb11;
401+
uint64_t nb1;
402+
uint64_t nb2;
403+
} ggml_metal_kargs_get_rows;
404+
394405
#endif // GGML_METAL_IMPL

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3053,19 +3053,22 @@ static void ggml_metal_encode_node(
30533053
default: GGML_ABORT("not implemented");
30543054
}
30553055

3056-
// TODO: add ggml_metal_kargs struct
3056+
ggml_metal_kargs_get_rows args = {
3057+
/*.ne00 =*/ ne00,
3058+
/*.nb01 =*/ nb01,
3059+
/*.nb02 =*/ nb02,
3060+
/*.ne10 =*/ ne10,
3061+
/*.nb10 =*/ nb10,
3062+
/*.nb11 =*/ nb11,
3063+
/*.nb1 =*/ nb1,
3064+
/*.nb2 =*/ nb2,
3065+
};
3066+
30573067
[encoder setComputePipelineState:pipeline];
30583068
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
30593069
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
30603070
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
3061-
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
3062-
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
3063-
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5];
3064-
[encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6];
3065-
[encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7];
3066-
[encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8];
3067-
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:9];
3068-
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:10];
3071+
[encoder setBytes:&args length:sizeof(args) atIndex:3];
30693072

30703073
[encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
30713074
} break;

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 15 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -5885,28 +5885,21 @@ kernel void kernel_get_rows_q(
58855885
device const void * src0,
58865886
device const void * src1,
58875887
device float * dst,
5888-
constant int64_t & ne00,
5889-
constant uint64_t & nb01,
5890-
constant uint64_t & nb02,
5891-
constant int64_t & ne10,
5892-
constant uint64_t & nb10,
5893-
constant uint64_t & nb11,
5894-
constant uint64_t & nb1,
5895-
constant uint64_t & nb2,
5888+
constant ggml_metal_kargs_get_rows & args,
58965889
uint3 tgpig[[threadgroup_position_in_grid]],
58975890
uint tiitg[[thread_index_in_threadgroup]],
58985891
uint3 tptg [[threads_per_threadgroup]]) {
58995892
const int64_t i10 = tgpig.x;
59005893
const int64_t i11 = tgpig.y;
59015894

5902-
const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];
5895+
const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args.nb11 + i10*args.nb10))[0];
59035896

59045897
const int64_t i02 = i11;
59055898

5906-
for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) {
5899+
for (int64_t ind = tiitg; ind < args.ne00/16; ind += tptg.x) {
59075900
float4x4 temp;
5908-
dequantize_func(((device const block_q *) ((const device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
5909-
*(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
5901+
dequantize_func(((device const block_q *) ((const device char *) src0 + r*args.nb01 + i02*args.nb02)) + ind/nl, ind%nl, temp);
5902+
*(((device float4x4 *) ((device char *) dst + i11*args.nb2 + i10*args.nb1)) + ind) = temp;
59105903
}
59115904
}
59125905

@@ -5915,55 +5908,41 @@ kernel void kernel_get_rows_f(
59155908
device const void * src0,
59165909
device const void * src1,
59175910
device float * dst,
5918-
constant int64_t & ne00,
5919-
constant uint64_t & nb01,
5920-
constant uint64_t & nb02,
5921-
constant int64_t & ne10,
5922-
constant uint64_t & nb10,
5923-
constant uint64_t & nb11,
5924-
constant uint64_t & nb1,
5925-
constant uint64_t & nb2,
5911+
constant ggml_metal_kargs_get_rows & args,
59265912
uint3 tgpig[[threadgroup_position_in_grid]],
59275913
uint tiitg[[thread_index_in_threadgroup]],
59285914
uint3 tptg [[threads_per_threadgroup]]) {
59295915
const int64_t i10 = tgpig.x;
59305916
const int64_t i11 = tgpig.y;
59315917

5932-
const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];
5918+
const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args.nb11 + i10*args.nb10))[0];
59335919

59345920
const int64_t i02 = i11;
59355921

5936-
for (int ind = tiitg; ind < ne00; ind += tptg.x) {
5937-
(( device float *) (( device char *) dst + i11*nb2 + i10*nb1))[ind] =
5938-
((const device T *) ((const device char *) src0 + i02*nb02 + r*nb01))[ind];
5922+
for (int ind = tiitg; ind < args.ne00; ind += tptg.x) {
5923+
(( device float *) (( device char *) dst + i11*args.nb2 + i10*args.nb1))[ind] =
5924+
((const device T *) ((const device char *) src0 + i02*args.nb02 + r*args.nb01))[ind];
59395925
}
59405926
}
59415927

59425928
kernel void kernel_get_rows_i32(
59435929
device const void * src0,
59445930
device const void * src1,
59455931
device int32_t * dst,
5946-
constant int64_t & ne00,
5947-
constant uint64_t & nb01,
5948-
constant uint64_t & nb02,
5949-
constant int64_t & ne10,
5950-
constant uint64_t & nb10,
5951-
constant uint64_t & nb11,
5952-
constant uint64_t & nb1,
5953-
constant uint64_t & nb2,
5932+
constant ggml_metal_kargs_get_rows & args,
59545933
uint3 tgpig[[threadgroup_position_in_grid]],
59555934
uint tiitg[[thread_index_in_threadgroup]],
59565935
uint3 tptg [[threads_per_threadgroup]]) {
59575936
const int64_t i10 = tgpig.x;
59585937
const int64_t i11 = tgpig.y;
59595938

5960-
const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];
5939+
const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args.nb11 + i10*args.nb10))[0];
59615940

59625941
const int64_t i02 = i11;
59635942

5964-
for (int ind = tiitg; ind < ne00; ind += tptg.x) {
5965-
(( device int32_t *) (( device char *) dst + i11*nb2 + i10*nb1))[ind] =
5966-
((const device int32_t *) ((const device char *) src0 + i02*nb02 + r*nb01))[ind];
5943+
for (int ind = tiitg; ind < args.ne00; ind += tptg.x) {
5944+
(( device int32_t *) (( device char *) dst + i11*args.nb2 + i10*args.nb1))[ind] =
5945+
((const device int32_t *) ((const device char *) src0 + i02*args.nb02 + r*args.nb01))[ind];
59675946
}
59685947
}
59695948

0 commit comments

Comments
 (0)