File tree Expand file tree Collapse file tree 3 files changed +19
-12
lines changed Expand file tree Collapse file tree 3 files changed +19
-12
lines changed Original file line number Diff line number Diff line change @@ -485,6 +485,12 @@ typedef struct {
485485 int32_t p1 ;
486486} ggml_metal_kargs_pad_reflect_1d ;
487487
488+ typedef struct {
489+ uint64_t nb1 ;
490+ int dim ;
491+ int max_period ;
492+ } ggml_metal_kargs_timestep_embedding ;
493+
488494typedef struct {
489495 int64_t ne0 ;
490496 float start ;
Original file line number Diff line number Diff line change @@ -3553,13 +3553,16 @@ static void ggml_metal_encode_node(
35533553
35543554 id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32].pipeline ;
35553555
3556- // TODO: add ggml_metal_kargs struct
3556+ ggml_metal_kargs_timestep_embedding args = {
3557+ /* .nb1 =*/ nb1,
3558+ /* .dim =*/ dim,
3559+ /* .max_period =*/ max_period
3560+ };
3561+
35573562 [encoder setComputePipelineState: pipeline];
35583563 [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
35593564 [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
3560- [encoder setBytes: &nb1 length: sizeof (nb1) atIndex: 2 ];
3561- [encoder setBytes: &dim length: sizeof (dim) atIndex: 3 ];
3562- [encoder setBytes: &max_period length: sizeof (max_period) atIndex: 4 ];
3565+ [encoder setBytes: &args length: sizeof (args) atIndex: 2 ];
35633566
35643567 const int nth = MIN (1024 , half);
35653568
Original file line number Diff line number Diff line change @@ -2770,27 +2770,25 @@ kernel void kernel_arange_f32(
27702770kernel void kernel_timestep_embedding_f32 (
27712771 device const char * src0,
27722772 device char * dst,
2773- constant uint64_t & nb1,
2774- constant int & dim,
2775- constant int & max_period,
2773+ constant ggml_metal_kargs_timestep_embedding & args,
27762774 uint3 tgpig[[threadgroup_position_in_grid]],
27772775 uint3 tpitg[[thread_position_in_threadgroup]],
27782776 uint3 ntg[[threads_per_threadgroup]]) {
27792777
27802778 int i = tgpig.x ;
2781- device float * embed_data = (device float *)(dst + i* nb1);
2779+ device float * embed_data = (device float *)(dst + i*args. nb1 );
27822780
2783- int half_ = dim / 2 ;
2781+ int half_ = args. dim / 2 ;
27842782 for (int j = tpitg.x ; j < half_; j += ntg.x ) {
27852783 float timestep = ((device float *)src0)[i];
2786- float freq = (float )exp (-log ((float )max_period) * j / half_);
2784+ float freq = (float )exp (-log ((float )args. max_period ) * j / half_);
27872785 float arg = timestep * freq;
27882786 embed_data[j ] = cos (arg);
27892787 embed_data[j + half_] = sin (arg);
27902788 }
27912789
2792- if (dim % 2 != 0 && tpitg.x == 0 ) {
2793- embed_data[dim] = 0 .f ;
2790+ if (args. dim % 2 != 0 && tpitg.x == 0 ) {
2791+ embed_data[args. dim ] = 0 .f ;
27942792 }
27952793}
27962794
You can’t perform that action at this time.
0 commit comments