@@ -6479,39 +6479,41 @@ kernel void kernel_pool_2d_max_f32(
64796479 const int cur_oh = idx % O_HW / OW;
64806480 const int cur_ow = idx % O_HW % OW;
64816481
6482- device const float * i_ptr = src0 + nc * I_HW;
6483- device float * o_ptr = dst + nc * O_HW;
6482+ device const float * i_ptr = src0 + nc * I_HW;
6483+ device float * o_ptr = dst + nc * O_HW;
64846484
64856485 const int start_h = cur_oh * s1 - p1;
6486- const int bh = MAX (0 , start_h);
6486+ const int bh = MAX (0 , start_h);
64876487 const int eh = MIN (IH, start_h + k1);
64886488 const int start_w = cur_ow * s0 - p0;
6489- const int bw = MAX (0 , start_w);
6489+ const int bw = MAX (0 , start_w);
64906490 const int ew = MIN (IW, start_w + k0);
6491+
64916492 float res = -INFINITY;
64926493
64936494 for (int i = bh; i < eh; i += 1 ) {
64946495 for (int j = bw; j < ew; j += 1 ) {
64956496 res = MAX (res, i_ptr[i * IW + j]);
64966497 }
64976498 }
6499+
64986500 o_ptr[cur_oh * OW + cur_ow] = res;
64996501}
65006502
65016503kernel void kernel_pool_2d_avg_f32 (
6502- device const float * src0,
6503- device float * dst,
6504- constant int32_t & k0,
6505- constant int32_t & k1,
6506- constant int32_t & s0,
6507- constant int32_t & s1,
6508- constant int32_t & p0,
6509- constant int32_t & p1,
6510- constant int64_t & IH,
6511- constant int64_t & IW,
6512- constant int64_t & OH,
6513- constant int64_t & OW,
6514- constant int64_t & parallel_elements,
6504+ device const float * src0,
6505+ device float * dst,
6506+ constant int32_t & k0,
6507+ constant int32_t & k1,
6508+ constant int32_t & s0,
6509+ constant int32_t & s1,
6510+ constant int32_t & p0,
6511+ constant int32_t & p1,
6512+ constant int64_t & IH,
6513+ constant int64_t & IW,
6514+ constant int64_t & OH,
6515+ constant int64_t & OW,
6516+ constant int64_t & parallel_elements,
65156517 uint gid[[thread_position_in_grid]]) {
65166518
65176519 if (gid >= parallel_elements) {
@@ -6525,17 +6527,18 @@ kernel void kernel_pool_2d_avg_f32(
65256527 const int cur_oh = idx % O_HW / OW;
65266528 const int cur_ow = idx % O_HW % OW;
65276529
6528- device const float * i_ptr = src0 + nc * I_HW;
6529- device float * o_ptr = dst + nc * O_HW;
6530+ device const float * i_ptr = src0 + nc * I_HW;
6531+ device float * o_ptr = dst + nc * O_HW;
65306532
65316533 const int start_h = cur_oh * s1 - p1;
6532- const int bh = MAX (0 , start_h);
6534+ const int bh = MAX (0 , start_h);
65336535 const int eh = MIN (IH, start_h + k1);
65346536 const int start_w = cur_ow * s0 - p0;
6535- const int bw = MAX (0 , start_w);
6537+ const int bw = MAX (0 , start_w);
65366538 const int ew = MIN (IW, start_w + k0);
65376539 // const float scale = 1. / ((eh - bh) * (ew - bw));
65386540 const float scale = 1 . / (k0 * k1);
6541+
65396542 float res = 0 ;
65406543
65416544 for (int i = bh; i < eh; i += 1 ) {
@@ -6544,5 +6547,6 @@ kernel void kernel_pool_2d_avg_f32(
65446547 res += cur * scale;
65456548 }
65466549 }
6550+
65476551 o_ptr[cur_oh * OW + cur_ow] = res;
65486552}
0 commit comments