3 files changed: 11 additions and 4 deletions.
Original file line number | Diff line number | Diff line change
@@ -491,6 +491,10 @@ typedef struct {
491491 int max_period ;
492492} ggml_metal_kargs_timestep_embedding ;
493493
494+ typedef struct {
495+ float slope ;
496+ } ggml_metal_kargs_leaky_relu ;
497+
494498typedef struct {
495499 int64_t ncols ;
496500 int64_t ncols_pad ;
Original file line number | Diff line number | Diff line change
@@ -3617,11 +3617,14 @@ static void ggml_metal_encode_node(
36173617
36183618 id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline ;
36193619
3620- // TODO: add ggml_metal_kargs struct
3620+ ggml_metal_kargs_leaky_relu args = {
3621+ /* .slope =*/ slope
3622+ };
3623+
36213624 [encoder setComputePipelineState: pipeline];
36223625 [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
36233626 [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
3624- [encoder setBytes: &slope length: sizeof (slope) atIndex: 2 ];
3627+ [encoder setBytes: &args length: sizeof (args) atIndex: 2 ];
36253628
36263629 const int64_t n = ggml_nelements (dst);
36273630
Original file line number | Diff line number | Diff line change
@@ -2861,9 +2861,9 @@ template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_ar
28612861kernel void kernel_leaky_relu_f32 (
28622862 device const float * src0,
28632863 device float * dst,
2864- constant float & slope ,
2864+ constant ggml_metal_kargs_leaky_relu & args ,
28652865 uint tpig[[thread_position_in_grid]]) {
2866- dst[tpig] = src0[tpig] > 0 .0f ? src0[tpig] : src0[tpig] * slope;
2866+ dst[tpig] = src0[tpig] > 0 .0f ? src0[tpig] : src0[tpig] * args. slope ;
28672867}
28682868
28692869// ref: https://arxiv.org/pdf/2307.08691.pdf
You can’t perform that action at this time.
0 commit comments