@@ -6463,98 +6463,78 @@ template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t
64636463kernel void kernel_pool_2d_max_f32 (
64646464 device const float * src0,
64656465 device float * dst,
6466- constant int32_t & k0,
6467- constant int32_t & k1,
6468- constant int32_t & s0,
6469- constant int32_t & s1,
6470- constant int32_t & p0,
6471- constant int32_t & p1,
6472- constant int64_t & IH,
6473- constant int64_t & IW,
6474- constant int64_t & OH,
6475- constant int64_t & OW,
6476- constant int64_t & parallel_elements,
6466+ constant ggml_metal_kargs_pool_2d & args,
64776467 uint gid[[thread_position_in_grid]]) {
64786468
6479- if (gid >= parallel_elements) {
6469+ if (gid >= args. parallel_elements ) {
64806470 return ;
64816471 }
64826472
64836473 const int idx = gid;
6484- const int I_HW = IH * IW;
6485- const int O_HW = OH * OW;
6474+ const int I_HW = args. IH * args. IW ;
6475+ const int O_HW = args. OH * args. OW ;
64866476 const int nc = idx / O_HW;
6487- const int cur_oh = idx % O_HW / OW;
6488- const int cur_ow = idx % O_HW % OW;
6477+ const int cur_oh = idx % O_HW / args. OW ;
6478+ const int cur_ow = idx % O_HW % args. OW ;
64896479
64906480 device const float * i_ptr = src0 + nc * I_HW;
64916481 device float * o_ptr = dst + nc * O_HW;
64926482
6493- const int start_h = cur_oh * s1 - p1;
6483+ const int start_h = cur_oh * args. s1 - args. p1 ;
64946484 const int bh = MAX (0 , start_h);
6495- const int eh = MIN (IH, start_h + k1);
6496- const int start_w = cur_ow * s0 - p0;
6485+ const int eh = MIN (args. IH , start_h + args. k1 );
6486+ const int start_w = cur_ow * args. s0 - args. p0 ;
64976487 const int bw = MAX (0 , start_w);
6498- const int ew = MIN (IW, start_w + k0);
6488+ const int ew = MIN (args. IW , start_w + args. k0 );
64996489
65006490 float res = -INFINITY;
65016491
65026492 for (int i = bh; i < eh; i += 1 ) {
65036493 for (int j = bw; j < ew; j += 1 ) {
6504- res = MAX (res, i_ptr[i * IW + j]);
6494+ res = MAX (res, i_ptr[i * args. IW + j]);
65056495 }
65066496 }
65076497
6508- o_ptr[cur_oh * OW + cur_ow] = res;
6498+ o_ptr[cur_oh * args. OW + cur_ow] = res;
65096499}
65106500
65116501kernel void kernel_pool_2d_avg_f32 (
65126502 device const float * src0,
65136503 device float * dst,
6514- constant int32_t & k0,
6515- constant int32_t & k1,
6516- constant int32_t & s0,
6517- constant int32_t & s1,
6518- constant int32_t & p0,
6519- constant int32_t & p1,
6520- constant int64_t & IH,
6521- constant int64_t & IW,
6522- constant int64_t & OH,
6523- constant int64_t & OW,
6524- constant int64_t & parallel_elements,
6504+ constant ggml_metal_kargs_pool_2d & args,
65256505 uint gid[[thread_position_in_grid]]) {
65266506
6527- if (gid >= parallel_elements) {
6507+ if (gid >= args. parallel_elements ) {
65286508 return ;
65296509 }
65306510
65316511 const int idx = gid;
6532- const int I_HW = IH * IW;
6533- const int O_HW = OH * OW;
6512+ const int I_HW = args. IH * args. IW ;
6513+ const int O_HW = args. OH * args. OW ;
65346514 const int nc = idx / O_HW;
6535- const int cur_oh = idx % O_HW / OW;
6536- const int cur_ow = idx % O_HW % OW;
6515+ const int cur_oh = idx % O_HW / args. OW ;
6516+ const int cur_ow = idx % O_HW % args. OW ;
65376517
65386518 device const float * i_ptr = src0 + nc * I_HW;
65396519 device float * o_ptr = dst + nc * O_HW;
65406520
6541- const int start_h = cur_oh * s1 - p1;
6521+ const int start_h = cur_oh * args. s1 - args. p1 ;
65426522 const int bh = MAX (0 , start_h);
6543- const int eh = MIN (IH, start_h + k1);
6544- const int start_w = cur_ow * s0 - p0;
6523+ const int eh = MIN (args. IH , start_h + args. k1 );
6524+ const int start_w = cur_ow * args. s0 - args. p0 ;
65456525 const int bw = MAX (0 , start_w);
6546- const int ew = MIN (IW, start_w + k0);
6526+ const int ew = MIN (args. IW , start_w + args. k0 );
65476527 // const float scale = 1. / ((eh - bh) * (ew - bw));
6548- const float scale = 1 . / (k0 * k1);
6528+ const float scale = 1 . / (args. k0 * args. k1 );
65496529
65506530 float res = 0 ;
65516531
65526532 for (int i = bh; i < eh; i += 1 ) {
65536533 for (int j = bw; j < ew; j += 1 ) {
6554- float cur = i_ptr[i * IW + j];
6534+ float cur = i_ptr[i * args. IW + j];
65556535 res += cur * scale;
65566536 }
65576537 }
65586538
6559- o_ptr[cur_oh * OW + cur_ow] = res;
6539+ o_ptr[cur_oh * args. OW + cur_ow] = res;
65606540}
0 commit comments