Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ggml/src/ggml-metal/ggml-metal-device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1237,7 +1237,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_im2col(ggml_metal_library_
char base[256];
char name[256];

snprintf(base, 256, "kernel_im2col_ext_%s", ggml_type_name(op->type));
snprintf(base, 256, "kernel_im2col_%s", ggml_type_name(op->type));
snprintf(name, 256, "%s", base);

ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
Expand Down
8 changes: 4 additions & 4 deletions ggml/src/ggml-metal/ggml-metal-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2768,7 +2768,6 @@ int ggml_metal_op_im2col(ggml_metal_op_t ctx, int idx) {
const uint64_t ofs0 = op->src[1]->nb[is_2D ? 3 : 2] / 4;
const uint64_t ofs1 = op->src[1]->nb[is_2D ? 2 : 1] / 4;


ggml_metal_kargs_im2col args = {
/*.ofs0 =*/ ofs0,
/*.ofs1 =*/ ofs1,
Expand All @@ -2789,15 +2788,16 @@ int ggml_metal_op_im2col(ggml_metal_op_t ctx, int idx) {

ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_im2col(lib, op);

const uint64_t n_threads = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), N);
const int64_t quotient = N / n_threads + (N % n_threads > 0 ? 1 : 0);
GGML_ASSERT(KH*KW <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));

const uint64_t ntptg0 = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)/(KH*KW), N);

ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);

ggml_metal_encoder_dispatch_threadgroups(enc, quotient * CHW, OH, OW, n_threads, 1, 1);
ggml_metal_encoder_dispatch_threadgroups(enc, IC, OH, OW, ntptg0, KH, KW);

return 1;
}
Expand Down
164 changes: 88 additions & 76 deletions ggml/src/ggml-metal/ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -3987,8 +3987,72 @@ template [[host_name("kernel_rope_multi_f16")]] kernel kernel_rope_multi_t kerne
template [[host_name("kernel_rope_vision_f32")]] kernel kernel_rope_vision_t kernel_rope_vision<float>;
template [[host_name("kernel_rope_vision_f16")]] kernel kernel_rope_vision_t kernel_rope_vision<half>;

// Function type used for the explicit instantiations of kernel_im2col<T>.
//   args  - im2col parameters, passed in the constant address space
//   x     - source values (float), read-only, device address space
//   dst   - destination buffer as raw bytes, device address space
//   tgpig - position of the threadgroup in the grid
//   tgpg  - number of threadgroups in the grid
//   tpitg - position of the thread inside its threadgroup
//   ntg   - number of threads per threadgroup
typedef void (im2col_t)(
constant ggml_metal_kargs_im2col & args,
device const float * x,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]);

// im2col kernel: gathers input values from `x` into `dst`, which is reinterpreted as T.
//
// Index mapping used by this kernel:
//   - threadgroup position (tgpig) supplies the input channel and the output row/column
//   - threadgroup count    (tgpg)  supplies OH (y) and OW (z)
//   - thread position y/z  (tpitg) supplies the kernel row/column
//   - threadgroup size y/z (ntg)   supplies KH and KW
//   - thread position x is the first batch index handled by the thread; the thread then
//     advances by ntg[0] until it reaches args.N
//
// Source coordinates outside [0, args.IH) x [0, args.IW) produce a 0 in the destination.
template <typename T>
kernel void kernel_im2col(
        constant ggml_metal_kargs_im2col & args,
        device const float * x,
        device        char * dst,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3  tgpg[[threadgroups_per_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {
    // output extent comes from the grid size, kernel extent from the threadgroup size
    const int64_t n_oh = tgpg[1];
    const int64_t n_ow = tgpg[2];
    const int64_t n_kh = ntg[1];
    const int64_t n_kw = ntg[2];

    // coordinates owned by this thread
    const int64_t ch = tgpig[0];
    const int64_t oh = tgpig[1];
    const int64_t ow = tgpig[2];
    const int64_t kh = tpitg[1];
    const int64_t kw = tpitg[2];

    // first batch index handled by this thread
    const int64_t n0 = tpitg[0];

    // source coordinates: stride * output position + dilation * kernel position - padding
    const int64_t src_w = ow*args.s0 + kw*args.d0 - args.p0;
    const int64_t src_h = oh*args.s1 + kh*args.d1 - args.p1;

    const bool inside = src_h >= 0 && src_h < args.IH && src_w >= 0 && src_w < args.IW;

    device T * out = (device T *) (dst);

    int64_t dst_idx = (n0*n_oh*n_ow + oh*n_ow + ow)*args.CHW + (ch*(n_kh*n_kw) + kh*n_kw + kw);

    if (inside) {
        int64_t src_idx = n0*args.ofs0 + ch*args.ofs1 + src_h*args.IW + src_w;

        // copy one value per batch entry, stepping both offsets by ntg[0] batch entries
        for (int64_t n = n0; n < args.N; n += ntg[0]) {
            out[dst_idx] = x[src_idx];

            dst_idx += ntg[0]*args.CHW*n_oh*n_ow;
            src_idx += ntg[0]*args.ofs0;
        }
    } else {
        // the source coordinate is outside the input: write zeros for every batch entry
        for (int64_t n = n0; n < args.N; n += ntg[0]) {
            out[dst_idx] = 0.0f;

            dst_idx += ntg[0]*args.CHW*n_oh*n_ow;
        }
    }
}

// Explicit instantiations of kernel_im2col for float and half destination types,
// exported under the host names kernel_im2col_f32 and kernel_im2col_f16.
template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col<float>;
template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>;

// TODO: obsolete -- remove
//typedef void (im2col_t)(
//typedef void (im2col_ext_t)(
// constant ggml_metal_kargs_im2col & args,
// device const float * x,
// device char * dst,
Expand All @@ -3998,100 +4062,48 @@ template [[host_name("kernel_rope_vision_f16")]] kernel kernel_rope_vision_t ker
// uint3 ntg[[threads_per_threadgroup]]);
//
//template <typename T>
//kernel void kernel_im2col(
//kernel void kernel_im2col_ext(
// constant ggml_metal_kargs_im2col & args,
// device const float * x,
// device char * dst,
// uint3 tgpig[[threadgroup_position_in_grid]],
// uint3 tgpg[[threadgroups_per_grid]],
// uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW
// uint3 tpitg[[thread_position_in_threadgroup]],
// uint3 ntg[[threads_per_threadgroup]]) {
//// const int64_t IC = tgpg[0];
// const int64_t OH = tgpg[1];
// const int64_t OW = tgpg[2];
// uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1]
// const int64_t KHW = (int64_t)args.KHW;
//
//// const int64_t N = ntg[0];
// const int64_t KH = ntg[1];
// const int64_t KW = ntg[2];
// const int64_t d = tgpig[0] / args.CHW;
// const int64_t chw = tgpig[0] % args.CHW;
// const int64_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1)
// const int64_t HW = tgpig[0] % KHW;
//
// const int64_t in = tpitg[0];
// const int64_t ikh = tpitg[1];
// const int64_t ikw = tpitg[2];
// const int64_t tpitg_0 = (d * ntg[0]) + tpitg[0];
// if (tpitg_0 >= args.N) {
// return;
// }
//
// const int64_t iic = tgpig[0];
// const int64_t ioh = tgpig[1];
// const int64_t iow = tgpig[2];
// const int64_t tpitg_1 = HW / args.KW;
// const int64_t tpitg_2 = HW % args.KW;
//
// const int64_t iiw = iow*args.s0 + ikw*args.d0 - args.p0;
// const int64_t iih = ioh*args.s1 + ikh*args.d1 - args.p1;
// const int64_t iiw = tgpig[2] * args.s0 + tpitg_2 * args.d0 - args.p0;
// const int64_t iih = tgpig[1] * args.s1 + tpitg_1 * args.d1 - args.p1;
//
// const int64_t offset_dst = (in*OH*OW + ioh*OW + iow)*args.CHW + (iic*(KH*KW) + ikh*KW + ikw);
// const int64_t offset_dst =
// (tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * args.CHW +
// (tgpig_0 * KHW + tpitg_1 * args.KW + tpitg_2);
//
// device T * pdst = (device T *) (dst);
//
// if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) {
// pdst[offset_dst] = 0.0f;
// } else {
// const int64_t offset_src = in*args.ofs0 + iic*args.ofs1 + iih*args.IW + iiw;
// pdst[offset_dst] = x[offset_src];
// const int64_t offset_src = tpitg_0 * args.ofs0 + tgpig_0 * args.ofs1;
// pdst[offset_dst] = x[offset_src + iih * args.IW + iiw];
// }
//}
//
//template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col<float>;
//template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>;

// Function type used for the explicit instantiations of kernel_im2col_ext<T>.
//   args  - im2col parameters, passed in the constant address space
//   x     - source values (float), read-only, device address space
//   dst   - destination buffer as raw bytes, device address space
//   tgpig - position of the threadgroup in the grid
//   tgpg  - number of threadgroups in the grid
//   tpitg - position of the thread inside its threadgroup
//   ntg   - number of threads per threadgroup
typedef void (im2col_ext_t)(
constant ggml_metal_kargs_im2col & args,
device const float * x,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]);

// im2col kernel variant in which the x dimension of the grid carries a flattened index.
//
// tgpig[0] is decomposed as follows:
//   tgpig[0] / args.CHW            -> block of batch entries
//   (tgpig[0] % args.CHW) / KHW    -> input channel
//   tgpig[0] % KHW                 -> flattened kernel position, split into row/column by args.KW
// The batch index of a thread is block * ntg[0] + tpitg[0]; threads whose batch index is
// >= args.N exit without writing.
//
// Output row/column come from tgpig[1]/tgpig[2]; the output extent from tgpg[1]/tgpg[2].
// Source coordinates outside [0, args.IH) x [0, args.IW) produce a 0 in the destination.
template <typename T>
kernel void kernel_im2col_ext(
        constant ggml_metal_kargs_im2col & args,
        device const float * x,
        device        char * dst,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3  tgpg[[threadgroups_per_grid]],      // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {  // [M, 1, 1]
    const int64_t khw = (int64_t)args.KHW;

    // batch index handled by this thread; surplus threads have nothing to do
    const int64_t batch_blk = tgpig[0] / args.CHW;
    const int64_t batch_idx = (batch_blk * ntg[0]) + tpitg[0];
    if (batch_idx >= args.N) {
        return;
    }

    // input channel and kernel position, recovered from the flattened grid x index
    const int64_t chan_pos = tgpig[0] % args.CHW;
    const int64_t ic       = chan_pos / khw; // 0 ~ (IC - 1)
    const int64_t k_flat   = tgpig[0] % khw;
    const int64_t kh       = k_flat / args.KW;
    const int64_t kw       = k_flat % args.KW;

    // source coordinates: stride * output position + dilation * kernel position - padding
    const int64_t src_w = tgpig[2] * args.s0 + kw * args.d0 - args.p0;
    const int64_t src_h = tgpig[1] * args.s1 + kh * args.d1 - args.p1;

    // destination index = output row * CHW + position inside the row
    const int64_t out_row = batch_idx * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2];
    const int64_t out_col = ic * khw + kh * args.KW + kw;
    const int64_t dst_idx = out_row * args.CHW + out_col;

    device T * out = (device T *) (dst);

    const bool inside = src_h >= 0 && src_h < args.IH && src_w >= 0 && src_w < args.IW;

    if (inside) {
        const int64_t src_base = batch_idx * args.ofs0 + ic * args.ofs1;
        out[dst_idx] = x[src_base + src_h * args.IW + src_w];
    } else {
        // the source coordinate is outside the input
        out[dst_idx] = 0.0f;
    }
}

// Explicit instantiations of kernel_im2col_ext for float and half destination types,
// exported under the host names kernel_im2col_ext_f32 and kernel_im2col_ext_f16.
template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext<float>;
template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext<half>;
//template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext<float>;
//template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext<half>;

typedef void (conv_transpose_1d_t)(
constant ggml_metal_kargs_conv_transpose_1d & args,
Expand Down
Loading