Skip to content

Commit 64e4a7e

Browse files
author
alexju
committed
metal : refactor leaky_relu parameters into a struct
1 parent e19266b commit 64e4a7e

File tree

3 files changed

+11
-4
lines changed

3 files changed

+11
-4
lines changed

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,10 @@ typedef struct {
491 491
int max_period;
492 492
} ggml_metal_kargs_timestep_embedding;
493 493

494+
typedef struct {
495+
float slope;
496+
} ggml_metal_kargs_leaky_relu;
497+
494 498
typedef struct {
495 499
int64_t ncols;
496 500
int64_t ncols_pad;

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3617,11 +3617,14 @@ static void ggml_metal_encode_node(
3617 3617

3618 3618
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline;
3619 3619

3620-
// TODO: add ggml_metal_kargs struct
3620+
ggml_metal_kargs_leaky_relu args = {
3621+
/*.slope =*/ slope
3622+
};
3623+
3621 3624
[encoder setComputePipelineState:pipeline];
3622 3625
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
3623 3626
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
3624-
[encoder setBytes:&slope length:sizeof(slope) atIndex:2];
3627+
[encoder setBytes:&args length:sizeof(args) atIndex:2];
3625 3628

3626 3629
const int64_t n = ggml_nelements(dst);
3627 3630

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2861,9 +2861,9 @@ template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_ar
2861 2861
kernel void kernel_leaky_relu_f32(
2862 2862
device const float * src0,
2863 2863
device float * dst,
2864-
constant float & slope,
2864+
constant ggml_metal_kargs_leaky_relu & args,
2865 2865
uint tpig[[thread_position_in_grid]]) {
2866-
dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope;
2866+
dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * args.slope;
2867 2867
}
2868 2868

2869 2869
// ref: https://arxiv.org/pdf/2307.08691.pdf

0 commit comments

Comments
 (0)