Skip to content

Commit b4d3c16

Browse files
committed
add pool_2d
Signed-off-by: Junhee Yoo <[email protected]>
1 parent f010b77 commit b4d3c16

File tree

2 files changed

+158
-1
lines changed

2 files changed

+158
-1
lines changed

ggml/src/ggml-metal.m

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,8 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
272272
GGML_METAL_KERNEL_TYPE_SIN,
273273
GGML_METAL_KERNEL_TYPE_COS,
274274
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
275+
GGML_METAL_KERNEL_TYPE_AVG_POOL_2D_F32,
276+
GGML_METAL_KERNEL_TYPE_MAX_POOL_2D_F32,
275277

276278
GGML_METAL_KERNEL_TYPE_COUNT
277279
};
@@ -716,6 +718,8 @@ @implementation GGMLMetalClass
716718
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true);
717719
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
718720
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
721+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_AVG_POOL_2D_F32, avg_pool_2d_f32, true);
722+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MAX_POOL_2D_F32, max_pool_2d_f32, true);
719723
}
720724

721725
[metal_library release];
@@ -844,8 +848,9 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
844848
case GGML_OP_IM2COL:
845849
return op->src[0]->type == GGML_TYPE_F16;
846850
case GGML_OP_POOL_1D:
847-
case GGML_OP_POOL_2D:
848851
return false;
852+
case GGML_OP_POOL_2D:
853+
return true;
849854
case GGML_OP_UPSCALE:
850855
case GGML_OP_PAD:
851856
case GGML_OP_ARANGE:
@@ -3001,6 +3006,63 @@ static void ggml_metal_encode_node(
30013006

30023007
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
30033008
} break;
3009+
case GGML_OP_POOL_2D:
3010+
{
3011+
GGML_ASSERT(src0t == GGML_TYPE_F32 && src0t == dstt);
3012+
3013+
const int32_t* opts = dst->op_params;
3014+
enum ggml_op_pool op = opts[0];
3015+
3016+
id<MTLComputePipelineState> pipeline = nil;
3017+
switch (src0t) {
3018+
case GGML_TYPE_F32: {
3019+
switch(op) {
3020+
case GGML_OP_POOL_AVG:
3021+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_AVG_POOL_2D_F32].pipeline; break;
3022+
case GGML_OP_POOL_MAX:
3023+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MAX_POOL_2D_F32].pipeline; break;
3024+
default: GGML_ASSERT(false && "not implemented");
3025+
}
3026+
} break;
3027+
default: GGML_ASSERT(false && "not implemented");
3028+
}
3029+
3030+
const int32_t k0 = opts[1];
3031+
const int32_t k1 = opts[2];
3032+
const int32_t s0 = opts[3];
3033+
const int32_t s1 = opts[4];
3034+
const int32_t p0 = opts[5];
3035+
const int32_t p1 = opts[6];
3036+
3037+
const int64_t IH = src0->ne[1];
3038+
const int64_t IW = src0->ne[0];
3039+
3040+
const int64_t N = dst->ne[3];
3041+
const int64_t OC = dst->ne[2];
3042+
const int64_t OH = dst->ne[1];
3043+
const int64_t OW = dst->ne[0];
3044+
3045+
const int64_t parallel_elements = N * OC * OH * OW;
3046+
const int64_t n_threads = MIN((int64_t)[pipeline maxTotalThreadsPerThreadgroup], parallel_elements);
3047+
const int64_t n_tg = (parallel_elements + n_threads - 1) / n_threads;
3048+
3049+
[encoder setComputePipelineState:pipeline];
3050+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
3051+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
3052+
[encoder setBytes:&k0 length:sizeof(int32_t) atIndex:2];
3053+
[encoder setBytes:&k1 length:sizeof(int32_t) atIndex:3];
3054+
[encoder setBytes:&s0 length:sizeof(int32_t) atIndex:4];
3055+
[encoder setBytes:&s1 length:sizeof(int32_t) atIndex:5];
3056+
[encoder setBytes:&p0 length:sizeof(int32_t) atIndex:6];
3057+
[encoder setBytes:&p1 length:sizeof(int32_t) atIndex:7];
3058+
[encoder setBytes:&IH length:sizeof(int64_t) atIndex:8];
3059+
[encoder setBytes:&IW length:sizeof(int64_t) atIndex:9];
3060+
[encoder setBytes:&OH length:sizeof(int64_t) atIndex:10];
3061+
[encoder setBytes:&OW length:sizeof(int64_t) atIndex:11];
3062+
[encoder setBytes:&parallel_elements length:sizeof(int64_t) atIndex:12];
3063+
3064+
[encoder dispatchThreadgroups:MTLSizeMake(n_tg, 1, 1) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)];
3065+
} break;
30043066
default:
30053067
{
30063068
GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));

ggml/src/ggml-metal.metal

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6372,3 +6372,98 @@ template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t
63726372
template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_s_f32_impl>>;
63736373
template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl>>;
63746374
template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl>>;
6375+
6376+
// 2D max pooling over f32 data. Each thread computes exactly one output element.
//
// The flat thread index is split into a plane index and an (output row, output
// column) pair; input and output are addressed as consecutive planes of IH*IW
// and OH*OW floats respectively.
//
//   k0, k1 : window extent along the width / height axis
//   s0, s1 : step between consecutive windows along the width / height axis
//   p0, p1 : amount subtracted from the window origin along the width / height axis
//   IH, IW : input plane height / width
//   OH, OW : output plane height / width
//   parallel_elements : total number of output elements; threads with an index
//                       at or beyond it exit immediately
//
// The window is clipped to the input plane. If the clipped window is empty the
// output element is -INFINITY.
kernel void kernel_max_pool_2d_f32(
        device  const float * src0,
        device        float * dst,
        constant    int32_t & k0,
        constant    int32_t & k1,
        constant    int32_t & s0,
        constant    int32_t & s1,
        constant    int32_t & p0,
        constant    int32_t & p1,
        constant    int64_t & IH,
        constant    int64_t & IW,
        constant    int64_t & OH,
        constant    int64_t & OW,
        constant    int64_t & parallel_elements,
        uint        gid[[thread_position_in_grid]]) {

    // The dispatch may be rounded up to a whole threadgroup; surplus threads do nothing.
    if (gid >= parallel_elements) {
        return;
    }

    const int out_index = gid;
    const int in_plane  = IH * IW;
    const int out_plane = OH * OW;

    // Decompose the flat index: plane first, then row/column inside the plane.
    const int plane    = out_index / out_plane;
    const int in_plane_pos = out_index % out_plane;
    const int out_row  = in_plane_pos / OW;
    const int out_col  = in_plane_pos % OW;

    device const float * in_base  = src0 + plane * in_plane;
    device       float * out_base = dst  + plane * out_plane;

    // Window bounds along the height axis, clipped to [0, IH).
    const int win_top   = out_row * s1 - p1;
    const int row_begin = MAX(0,  win_top);
    const int row_end   = MIN(IH, win_top + k1);

    // Window bounds along the width axis, clipped to [0, IW).
    const int win_left  = out_col * s0 - p0;
    const int col_begin = MAX(0,  win_left);
    const int col_end   = MIN(IW, win_left + k0);

    float best = -INFINITY;

    for (int row = row_begin; row < row_end; ++row) {
        for (int col = col_begin; col < col_end; ++col) {
            const float value = in_base[row * IW + col];
            best = MAX(best, value);
        }
    }

    out_base[out_row * OW + out_col] = best;
}
6421+
6422+
// 2D average pooling over f32 data. Each thread computes exactly one output element.
//
// The flat thread index is split into a plane index and an (output row, output
// column) pair; input and output are addressed as consecutive planes of IH*IW
// and OH*OW floats respectively.
//
//   k0, k1 : window extent along the width / height axis
//   s0, s1 : step between consecutive windows along the width / height axis
//   p0, p1 : amount subtracted from the window origin along the width / height axis
//   IH, IW : input plane height / width
//   OH, OW : output plane height / width
//   parallel_elements : total number of output elements; threads with an index
//                       at or beyond it exit immediately
//
// The window is clipped to the input plane, but the sum is always normalised
// by the full window size k0 * k1, not by the number of elements actually
// read. If the clipped window is empty the output element is 0.
kernel void kernel_avg_pool_2d_f32(
        device  const float * src0,
        device        float * dst,
        constant    int32_t & k0,
        constant    int32_t & k1,
        constant    int32_t & s0,
        constant    int32_t & s1,
        constant    int32_t & p0,
        constant    int32_t & p1,
        constant    int64_t & IH,
        constant    int64_t & IW,
        constant    int64_t & OH,
        constant    int64_t & OW,
        constant    int64_t & parallel_elements,
        uint        gid[[thread_position_in_grid]]) {

    // The dispatch may be rounded up to a whole threadgroup; surplus threads do nothing.
    if (gid >= parallel_elements) {
        return;
    }

    const int idx  = gid;
    const int I_HW = IH * IW;
    const int O_HW = OH * OW;

    // Decompose the flat index: plane first, then row/column inside the plane.
    const int nc     = idx / O_HW;
    const int cur_oh = idx % O_HW / OW;
    const int cur_ow = idx % O_HW % OW;

    device const float * i_ptr = src0 + nc * I_HW;
    device       float * o_ptr = dst  + nc * O_HW;

    // Window bounds along the height axis, clipped to [0, IH).
    const int start_h = cur_oh * s1 - p1;
    const int bh = MAX(0,  start_h);
    const int eh = MIN(IH, start_h + k1);

    // Window bounds along the width axis, clipped to [0, IW).
    const int start_w = cur_ow * s0 - p0;
    const int bw = MAX(0,  start_w);
    const int ew = MIN(IW, start_w + k0);

    // Normalise by the nominal window size (k0 * k1), even when the window is
    // clipped. Single-precision literal: Metal has no double type.
    const float scale = 1.0f / (k0 * k1);

    float res = 0;

    for (int i = bh; i < eh; i += 1) {
        for (int j = bw; j < ew; j += 1) {
            const float cur = i_ptr[i * IW + j];
            // Scale each term before accumulating (keeps the original summation order).
            res += cur * scale;
        }
    }

    o_ptr[cur_oh * OW + cur_ow] = res;
}

0 commit comments

Comments
 (0)