Skip to content

Commit 53d36ee

Browse files
author
alexju
committed
metal : refactor timestep_embedding parameters into a struct
1 parent 7ab450d commit 53d36ee

File tree

3 files changed

+19
-12
lines changed

3 files changed

+19
-12
lines changed

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -485,6 +485,12 @@ typedef struct {
485485
int32_t p1;
486486
} ggml_metal_kargs_pad_reflect_1d;
487487

488+
typedef struct {
489+
uint64_t nb1;
490+
int dim;
491+
int max_period;
492+
} ggml_metal_kargs_timestep_embedding;
493+
488494
typedef struct {
489495
int64_t ne0;
490496
float start;

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3553,13 +3553,16 @@ static void ggml_metal_encode_node(
35533553

35543554
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32].pipeline;
35553555

3556-
// TODO: add ggml_metal_kargs struct
3556+
ggml_metal_kargs_timestep_embedding args = {
3557+
/*.nb1 =*/ nb1,
3558+
/*.dim =*/ dim,
3559+
/*.max_period =*/ max_period
3560+
};
3561+
35573562
[encoder setComputePipelineState:pipeline];
35583563
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
35593564
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
3560-
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:2];
3561-
[encoder setBytes:&dim length:sizeof(dim) atIndex:3];
3562-
[encoder setBytes:&max_period length:sizeof(max_period) atIndex:4];
3565+
[encoder setBytes:&args length:sizeof(args) atIndex:2];
35633566

35643567
const int nth = MIN(1024, half);
35653568

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2770,27 +2770,25 @@ kernel void kernel_arange_f32(
27702770
kernel void kernel_timestep_embedding_f32(
27712771
device const char * src0,
27722772
device char * dst,
2773-
constant uint64_t & nb1,
2774-
constant int & dim,
2775-
constant int & max_period,
2773+
constant ggml_metal_kargs_timestep_embedding & args,
27762774
uint3 tgpig[[threadgroup_position_in_grid]],
27772775
uint3 tpitg[[thread_position_in_threadgroup]],
27782776
uint3 ntg[[threads_per_threadgroup]]) {
27792777

27802778
int i = tgpig.x;
2781-
device float * embed_data = (device float *)(dst + i*nb1);
2779+
device float * embed_data = (device float *)(dst + i*args.nb1);
27822780

2783-
int half_ = dim / 2;
2781+
int half_ = args.dim / 2;
27842782
for (int j = tpitg.x; j < half_; j += ntg.x) {
27852783
float timestep = ((device float *)src0)[i];
2786-
float freq = (float)exp(-log((float)max_period) * j / half_);
2784+
float freq = (float)exp(-log((float)args.max_period) * j / half_);
27872785
float arg = timestep * freq;
27882786
embed_data[j ] = cos(arg);
27892787
embed_data[j + half_] = sin(arg);
27902788
}
27912789

2792-
if (dim % 2 != 0 && tpitg.x == 0) {
2793-
embed_data[dim] = 0.f;
2790+
if (args.dim % 2 != 0 && tpitg.x == 0) {
2791+
embed_data[args.dim] = 0.f;
27942792
}
27952793
}
27962794

0 commit comments

Comments
 (0)