@@ -2609,12 +2609,7 @@ typedef void (conv_transpose_1d_t)(
26092609 device const float * src0,
26102610 device const float * src1,
26112611 device char * dst,
2612- constant int32_t & IC,
2613- constant int32_t & IL,
2614- constant int32_t & K,
2615- constant int32_t & s0,
2616- constant uint64_t & nb0,
2617- constant uint64_t & nb1,
2612+ constant ggml_metal_kargs_conv_transpose_1d & args,
26182613 uint3 tgpig[[threadgroup_position_in_grid]],
26192614 uint3 tgpg[[threadgroups_per_grid]]);
26202615
@@ -2623,29 +2618,24 @@ kernel void kernel_conv_transpose_1d(
26232618 device const T * src0,
26242619 device const float * src1,
26252620 device char * dst,
2626- constant int32_t & IC,
2627- constant int32_t & IL,
2628- constant int32_t & K,
2629- constant int32_t & s0,
2630- constant uint64_t & nb0,
2631- constant uint64_t & nb1,
2621+ constant ggml_metal_kargs_conv_transpose_1d & args,
26322622 uint3 tgpig[[threadgroup_position_in_grid]],
26332623 uint3 tgpg[[threadgroups_per_grid]]) {
26342624
26352625 float v = 0 .0f ;
26362626
2637- for (int64_t c = 0 ; c < IC; c++) {
2638- const int32_t kernel_offset = c * tgpg[1 ] * K + K * tgpig[1 ];
2639- const int32_t input_offset = c * IL;
2627+ for (int64_t c = 0 ; c < args. IC ; c++) {
2628+ const int32_t kernel_offset = c * tgpg[1 ] * args. K + args. K * tgpig[1 ];
2629+ const int32_t input_offset = c * args. IL ;
26402630
2641- for (int64_t i = 0 ; i < IL; i++) {
2642- if (tgpig[0 ] >= i * s0 && tgpig[0 ] < i * s0 + K) {
2643- v += src0[kernel_offset + tgpig[0 ] - i * s0] * src1[input_offset + i];
2631+ for (int64_t i = 0 ; i < args. IL ; i++) {
2632+ if (tgpig[0 ] >= i * args. s0 && tgpig[0 ] < i * args. s0 + args. K ) {
2633+ v += src0[kernel_offset + tgpig[0 ] - i * args. s0 ] * src1[input_offset + i];
26442634 }
26452635 }
26462636 }
26472637
2648- device float * dst_ptr = (device float *) (dst + tgpig[0 ] * nb0 + tgpig[1 ] * nb1);
2638+ device float * dst_ptr = (device float *) (dst + tgpig[0 ] * args. nb0 + tgpig[1 ] * args. nb1 );
26492639
26502640 dst_ptr[0 ] = v;
26512641}
@@ -2655,12 +2645,7 @@ kernel void kernel_conv_transpose_1d<float>(
26552645 device const float * src0,
26562646 device const float * src1,
26572647 device char * dst,
2658- constant int32_t & IC,
2659- constant int32_t & IL,
2660- constant int32_t & K,
2661- constant int32_t & s0,
2662- constant uint64_t & nb0,
2663- constant uint64_t & nb1,
2648+ constant ggml_metal_kargs_conv_transpose_1d & args,
26642649 uint3 tgpig[[threadgroup_position_in_grid]],
26652650 uint3 tgpg[[threadgroups_per_grid]]);
26662651
@@ -2669,12 +2654,7 @@ kernel void kernel_conv_transpose_1d<half>(
26692654 device const half * src0,
26702655 device const float * src1,
26712656 device char * dst,
2672- constant int32_t & IC,
2673- constant int32_t & IL,
2674- constant int32_t & K,
2675- constant int32_t & s0,
2676- constant uint64_t & nb0,
2677- constant uint64_t & nb1,
2657+ constant ggml_metal_kargs_conv_transpose_1d & args,
26782658 uint3 tgpig[[threadgroup_position_in_grid]],
26792659 uint3 tgpg[[threadgroups_per_grid]]);
26802660
0 commit comments