Skip to content

Commit 746e79e

Browse files
committed
apply review
Signed-off-by: Junhee Yoo <[email protected]>
1 parent bb9949b commit 746e79e

File tree

2 files changed

+31
-25
lines changed

2 files changed

+31
-25
lines changed

ggml/src/ggml-metal.m

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -854,7 +854,6 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
854854
case GGML_OP_POOL_1D:
855855
return false;
856856
case GGML_OP_POOL_2D:
857-
return true;
858857
case GGML_OP_UPSCALE:
859858
case GGML_OP_PAD:
860859
case GGML_OP_ARANGE:
@@ -2554,6 +2553,8 @@ static void ggml_metal_encode_node(
25542553
} break;
25552554
case GGML_OP_IM2COL:
25562555
{
2556+
GGML_ASSERT(ggml_is_contiguous(src0));
2557+
GGML_ASSERT(ggml_is_contiguous(src1));
25572558
GGML_ASSERT(src0->type == GGML_TYPE_F16);
25582559
GGML_ASSERT(src1->type == GGML_TYPE_F32);
25592560
GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
@@ -2620,7 +2621,7 @@ static void ggml_metal_encode_node(
26202621
[encoder setBytes:&d1 length:sizeof( int32_t) atIndex:12];
26212622

26222623
if (is_gt_mttpt) {
2623-
[encoder setBytes:&N length:sizeof(int32_t) atIndex:13];
2624+
[encoder setBytes:&N length:sizeof(int32_t) atIndex:13];
26242625
[encoder setBytes:&KH length:sizeof(int32_t) atIndex:14];
26252626
[encoder setBytes:&KW length:sizeof(int32_t) atIndex:15];
26262627

@@ -3034,9 +3035,10 @@ static void ggml_metal_encode_node(
30343035
} break;
30353036
case GGML_OP_POOL_2D:
30363037
{
3038+
GGML_ASSERT(ggml_is_contiguous(src0));
30373039
GGML_ASSERT(src0t == GGML_TYPE_F32 && src0t == dstt);
30383040

3039-
const int32_t* opts = dst->op_params;
3041+
const int32_t * opts = dst->op_params;
30403042
enum ggml_op_pool op = opts[0];
30413043

30423044
id<MTLComputePipelineState> pipeline = nil;
@@ -3063,7 +3065,7 @@ static void ggml_metal_encode_node(
30633065
const int64_t IH = src0->ne[1];
30643066
const int64_t IW = src0->ne[0];
30653067

3066-
const int64_t N = dst->ne[3];
3068+
const int64_t N = dst->ne[3];
30673069
const int64_t OC = dst->ne[2];
30683070
const int64_t OH = dst->ne[1];
30693071
const int64_t OW = dst->ne[0];

ggml/src/ggml-metal.metal

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -6479,39 +6479,41 @@ kernel void kernel_pool_2d_max_f32(
64796479
const int cur_oh = idx % O_HW / OW;
64806480
const int cur_ow = idx % O_HW % OW;
64816481

6482-
device const float* i_ptr = src0 + nc * I_HW;
6483-
device float* o_ptr = dst + nc * O_HW;
6482+
device const float * i_ptr = src0 + nc * I_HW;
6483+
device float * o_ptr = dst + nc * O_HW;
64846484

64856485
const int start_h = cur_oh * s1 - p1;
6486-
const int bh = MAX(0, start_h);
6486+
const int bh = MAX(0, start_h);
64876487
const int eh = MIN(IH, start_h + k1);
64886488
const int start_w = cur_ow * s0 - p0;
6489-
const int bw = MAX(0, start_w);
6489+
const int bw = MAX(0, start_w);
64906490
const int ew = MIN(IW, start_w + k0);
6491+
64916492
float res = -INFINITY;
64926493

64936494
for (int i = bh; i < eh; i += 1) {
64946495
for (int j = bw; j < ew; j += 1) {
64956496
res = MAX(res, i_ptr[i * IW + j]);
64966497
}
64976498
}
6499+
64986500
o_ptr[cur_oh * OW + cur_ow] = res;
64996501
}
65006502

65016503
kernel void kernel_pool_2d_avg_f32(
6502-
device const float* src0,
6503-
device float* dst,
6504-
constant int32_t& k0,
6505-
constant int32_t& k1,
6506-
constant int32_t& s0,
6507-
constant int32_t& s1,
6508-
constant int32_t& p0,
6509-
constant int32_t& p1,
6510-
constant int64_t& IH,
6511-
constant int64_t& IW,
6512-
constant int64_t& OH,
6513-
constant int64_t& OW,
6514-
constant int64_t& parallel_elements,
6504+
device const float * src0,
6505+
device float * dst,
6506+
constant int32_t & k0,
6507+
constant int32_t & k1,
6508+
constant int32_t & s0,
6509+
constant int32_t & s1,
6510+
constant int32_t & p0,
6511+
constant int32_t & p1,
6512+
constant int64_t & IH,
6513+
constant int64_t & IW,
6514+
constant int64_t & OH,
6515+
constant int64_t & OW,
6516+
constant int64_t & parallel_elements,
65156517
uint gid[[thread_position_in_grid]]) {
65166518

65176519
if (gid >= parallel_elements) {
@@ -6525,17 +6527,18 @@ kernel void kernel_pool_2d_avg_f32(
65256527
const int cur_oh = idx % O_HW / OW;
65266528
const int cur_ow = idx % O_HW % OW;
65276529

6528-
device const float* i_ptr = src0 + nc * I_HW;
6529-
device float* o_ptr = dst + nc * O_HW;
6530+
device const float * i_ptr = src0 + nc * I_HW;
6531+
device float * o_ptr = dst + nc * O_HW;
65306532

65316533
const int start_h = cur_oh * s1 - p1;
6532-
const int bh = MAX(0, start_h);
6534+
const int bh = MAX(0, start_h);
65336535
const int eh = MIN(IH, start_h + k1);
65346536
const int start_w = cur_ow * s0 - p0;
6535-
const int bw = MAX(0, start_w);
6537+
const int bw = MAX(0, start_w);
65366538
const int ew = MIN(IW, start_w + k0);
65376539
// const float scale = 1. / ((eh - bh) * (ew - bw));
65386540
const float scale = 1. / (k0 * k1);
6541+
65396542
float res = 0;
65406543

65416544
for (int i = bh; i < eh; i += 1) {
@@ -6544,5 +6547,6 @@ kernel void kernel_pool_2d_avg_f32(
65446547
res += cur * scale;
65456548
}
65466549
}
6550+
65476551
o_ptr[cur_oh * OW + cur_ow] = res;
65486552
}

0 commit comments

Comments
 (0)