Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions ggml/src/ggml-metal/ggml-metal-device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,33 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cpy(ggml_metal_library_t l
return res;
}

ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pool_1d(ggml_metal_library_t lib, const ggml_tensor * op, ggml_op_pool op_pool) {
GGML_ASSERT(ggml_is_contiguous(op->src[0]));
GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32 && op->src[0]->type == op->type);

const char * pool_str = "undefined";
switch (op_pool) {
case GGML_OP_POOL_AVG: pool_str = "avg"; break;
case GGML_OP_POOL_MAX: pool_str = "max"; break;
default: GGML_ASSERT(false && "not implemented");
};

char base[256];
char name[256];

snprintf(base, sizeof(base), "kernel_pool_1d_%s_%s", pool_str, ggml_type_name(op->src[0]->type));
snprintf(name, sizeof(name), "%s", base);

ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
if (res) {
return res;
}

res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);

return res;
}

ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pool_2d(ggml_metal_library_t lib, const ggml_tensor * op, ggml_op_pool op_pool) {
GGML_ASSERT(ggml_is_contiguous(op->src[0]));
GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32 && op->src[0]->type == op->type);
Expand Down
1 change: 1 addition & 0 deletions ggml/src/ggml-metal/ggml-metal-device.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t l

ggml_metal_pipeline_t ggml_metal_library_get_pipeline_base (ggml_metal_library_t lib, enum ggml_op op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cpy (ggml_metal_library_t lib, enum ggml_type tsrc, enum ggml_type tdst);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pool_1d (ggml_metal_library_t lib, const struct ggml_tensor * op, enum ggml_op_pool op_pool);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pool_2d (ggml_metal_library_t lib, const struct ggml_tensor * op, enum ggml_op_pool op_pool);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_get_rows (ggml_metal_library_t lib, enum ggml_type tsrc);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_set_rows (ggml_metal_library_t lib, enum ggml_type tidx, enum ggml_type tdst);
Expand Down
2 changes: 1 addition & 1 deletion ggml/src/ggml-metal/ggml-metal-device.m
Original file line number Diff line number Diff line change
Expand Up @@ -673,7 +673,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
case GGML_OP_IM2COL:
return ggml_is_contiguous(op->src[1]) && op->src[1]->type == GGML_TYPE_F32 && (op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32);
case GGML_OP_POOL_1D:
return false;
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_UPSCALE:
return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST;
case GGML_OP_POOL_2D:
Expand Down
10 changes: 10 additions & 0 deletions ggml/src/ggml-metal/ggml-metal-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -714,6 +714,16 @@ typedef struct {
int64_t np;
} ggml_metal_kargs_pool_2d;


typedef struct {
int IW;
int OW;
int np;
int k0;
int s0;
int p0;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The order of the arguments does not match the initialization in ggml-metal-ops.cpp.

Also, some of these arguments like np need to be 64-bit integers.

} ggml_metal_kargs_pool_1d;

typedef struct {
int64_t ne00;
uint64_t nb01;
Expand Down
56 changes: 56 additions & 0 deletions ggml/src/ggml-metal/ggml-metal-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
{
n_fuse = ggml_metal_op_cpy(ctx, idx);
} break;
case GGML_OP_POOL_1D:
{
n_fuse = ggml_metal_op_pool_1d(ctx, idx);
} break;
case GGML_OP_POOL_2D:
{
n_fuse = ggml_metal_op_pool_2d(ctx, idx);
Expand Down Expand Up @@ -1331,6 +1335,58 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
return 1;
}

int ggml_metal_op_pool_1d(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);

ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;

GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);

const int32_t * opts = op->op_params;
ggml_op_pool op_pool = (ggml_op_pool) opts[0];

const int32_t k0 = opts[1];
const int32_t s0 = opts[2];
const int32_t p0 = opts[3];

const int64_t IW = op->src[0]->ne[0];

// planes are the remaining dims: N * OC * (and OH,OW in 2D; for 1D only OC & N + OW)
const int64_t N = op->ne[3];
const int64_t OC = op->ne[2];
const int64_t OW = op->ne[0];

const int64_t np = N * OC * OW;

ggml_metal_kargs_pool_1d args_pool_1d = {
/* .k0 = */ k0,
/* .s0 = */ s0,
/* .p0 = */ p0,
/* .IW = */ (int) IW,
/* .OW = */ (int) OW,
/* .np = */ (int) np
};

ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_pool_1d(lib, op, op_pool);

const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), (int) np);
const int ntg = (np + nth - 1) / nth;

ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args_pool_1d, sizeof(args_pool_1d), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);

ggml_metal_encoder_dispatch_threadgroups(enc, ntg, 1, 1, nth, 1, 1);

return 1;
}


int ggml_metal_op_pool_2d(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);

Expand Down
1 change: 1 addition & 0 deletions ggml/src/ggml-metal/ggml-metal-ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ int ggml_metal_op_ssm_conv (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_ssm_scan (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_rwkv (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_cpy (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_pool_1d (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_pool_2d (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_mul_mat (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_mul_mat_id (ggml_metal_op_t ctx, int idx);
Expand Down
65 changes: 65 additions & 0 deletions ggml/src/ggml-metal/ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -8720,3 +8720,68 @@ kernel void kernel_pool_2d_avg_f32(

o_ptr[cur_oh * args.OW + cur_ow] = res;
}


kernel void kernel_pool_1d_max_f32(
constant ggml_metal_kargs_pool_1d & args,
device const float * src0,
device float * dst,
uint gid[[thread_position_in_grid]]) {

if (gid >= args.np) {
return;
}

const int idx = gid;
const int O_L = args.OW;
const int nc = idx / O_L;
const int cur_o0 = idx % O_L;

device const float * i_ptr = src0 + nc * args.IW;
device float * o_ptr = dst + nc * O_L;

const int start = cur_o0 * args.s0 - args.p0;
const int b = MAX(0, start);
const int e = MIN(args.IW, start + args.k0);

float res = -INFINITY;

for (int j = b; j < e; ++j) {
res = MAX(res, i_ptr[j]);
}

o_ptr[cur_o0] = res;
}

kernel void kernel_pool_1d_avg_f32(
constant ggml_metal_kargs_pool_1d & args,
device const float * src0,
device float * dst,
uint gid[[thread_position_in_grid]]) {

if (gid >= args.np) {
return;
}

const int idx = gid;
const int O_L = args.OW;
const int nc = idx / O_L;
const int cur_o0 = idx % O_L;

device const float * i_ptr = src0 + nc * args.IW;
device float * o_ptr = dst + nc * O_L;

const int start = cur_o0 * args.s0 - args.p0;
const int b = MAX(0, start);
const int e = MIN(args.IW, start + args.k0);

const float scale = 1.0f / args.k0;

float res = 0.0f;

for (int j = b; j < e; ++j) {
res += i_ptr[j] * scale;
}

o_ptr[cur_o0] = res;
}