Skip to content

Commit 7ab450d

Browse files
author
alexju
committed
metal : refactor arange parameters into a struct
1 parent 2cd68ad commit 7ab450d

File tree

3 files changed

+17
-10
lines changed

3 files changed

+17
-10
lines changed

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -485,4 +485,10 @@ typedef struct {
485485
int32_t p1;
486486
} ggml_metal_kargs_pad_reflect_1d;
487487

488+
typedef struct { // argument block for kernel_arange_f32; the host uploads it with setBytes at buffer index 1
489+
int64_t ne0; // exclusive upper bound of the index loop in the kernel (i0 < ne0)
490+
float start; // additive term of the generated value: start + step * i0
491+
float step; // multiplier applied to the index i0
492+
} ggml_metal_kargs_arange;
493+
488494
#endif // GGML_METAL_IMPL

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3527,13 +3527,16 @@ static void ggml_metal_encode_node(
35273527
memcpy(&step, ((const int32_t *) dst->op_params) + 2, sizeof(float));
35283528

35293529
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline;
3530+
3531+
ggml_metal_kargs_arange args = {
3532+
/*.ne0 =*/ ne0,
3533+
/*.start =*/ start,
3534+
/*.step =*/ step
3535+
};
35303536

3531-
// TODO: add ggml_metal_kargs struct
35323537
[encoder setComputePipelineState:pipeline];
3533-
[encoder setBuffer:id_dst offset:offs_dst atIndex:0];
3534-
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:1];
3535-
[encoder setBytes:&start length:sizeof(start) atIndex:2];
3536-
[encoder setBytes:&step length:sizeof(step) atIndex:3];
3538+
[encoder setBuffer:id_dst offset:offs_dst atIndex:0];
3539+
[encoder setBytes:&args length:sizeof(args) atIndex:1];
35373540

35383541
const int nth = MIN(1024, ne0);
35393542

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2755,17 +2755,15 @@ kernel void kernel_pad_reflect_1d_f32(
27552755

27562756
kernel void kernel_arange_f32( // writes dst[i0] = start + step * i0 (as float) for i0 below ne0
27572757
device char * dst, // output buffer; reinterpreted as device float * below
2758-
constant int64_t & ne0,
2759-
constant float & start,
2760-
constant float & step,
2758+
constant ggml_metal_kargs_arange & args, // ne0 / start / step packed into one struct
27612759
uint3 tgpig[[threadgroup_position_in_grid]], // NOTE(review): not referenced in the body shown here
27622760
uint3 tpitg[[thread_position_in_threadgroup]],
27632761
uint3 ntg[[threads_per_threadgroup]]) {
27642762

27652763
device float * dst_ptr = (device float *) dst;
27662764

2767-
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
2768-
dst_ptr[i0] = start + step * i0;
2765+
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { // each thread starts at its in-threadgroup index and strides by the threadgroup size; NOTE(review): i0 is int while args.ne0 is int64_t - confirm ne0 always fits in int
2766+
dst_ptr[i0] = args.start + args.step * i0;
27692767
}
27702768
}
27712769

0 commit comments

Comments
 (0)