File tree Expand file tree Collapse file tree 3 files changed +17
-10
lines changed Expand file tree Collapse file tree 3 files changed +17
-10
lines changed Original file line number Diff line number Diff line change @@ -485,4 +485,10 @@ typedef struct {
485485 int32_t p1 ;
486486} ggml_metal_kargs_pad_reflect_1d ;
487487
488+ typedef struct {
489+ int64_t ne0 ;
490+ float start ;
491+ float step ;
492+ } ggml_metal_kargs_arange ;
493+
488494#endif // GGML_METAL_IMPL
Original file line number Diff line number Diff line change @@ -3527,13 +3527,16 @@ static void ggml_metal_encode_node(
35273527 memcpy (&step, ((const int32_t *) dst->op_params ) + 2 , sizeof (float ));
35283528
35293529 id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline ;
3530+
3531+ ggml_metal_kargs_arange args = {
3532+ /* .ne0 =*/ ne0,
3533+ /* .start =*/ start,
3534+ /* .step =*/ step
3535+ };
35303536
3531- // TODO: add ggml_metal_kargs struct
35323537 [encoder setComputePipelineState: pipeline];
3533- [encoder setBuffer: id_dst offset: offs_dst atIndex: 0 ];
3534- [encoder setBytes: &ne0 length: sizeof (ne0) atIndex: 1 ];
3535- [encoder setBytes: &start length: sizeof (start) atIndex: 2 ];
3536- [encoder setBytes: &step length: sizeof (step) atIndex: 3 ];
3538+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 0 ];
3539+ [encoder setBytes: &args length: sizeof (args) atIndex: 1 ];
35373540
35383541 const int nth = MIN (1024 , ne0);
35393542
Original file line number Diff line number Diff line change @@ -2755,17 +2755,15 @@ kernel void kernel_pad_reflect_1d_f32(
27552755
27562756kernel void kernel_arange_f32 (
27572757 device char * dst,
2758- constant int64_t & ne0,
2759- constant float & start,
2760- constant float & step,
2758+ constant ggml_metal_kargs_arange & args,
27612759 uint3 tgpig[[threadgroup_position_in_grid]],
27622760 uint3 tpitg[[thread_position_in_threadgroup]],
27632761 uint3 ntg[[threads_per_threadgroup]]) {
27642762
27652763 device float * dst_ptr = (device float *) dst;
27662764
2767- for (int i0 = tpitg.x ; i0 < ne0; i0 += ntg.x ) {
2768- dst_ptr[i0] = start + step * i0;
2765+ for (int i0 = tpitg.x ; i0 < args. ne0 ; i0 += ntg.x ) {
2766+ dst_ptr[i0] = args. start + args. step * i0;
27692767 }
27702768}
27712769
You can’t perform that action at this time.
0 commit comments