diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 9f91662cbd876..8aeefd2c6830f 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -1237,7 +1237,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_im2col(ggml_metal_library_ char base[256]; char name[256]; - snprintf(base, 256, "kernel_im2col_ext_%s", ggml_type_name(op->type)); + snprintf(base, 256, "kernel_im2col_%s", ggml_type_name(op->type)); snprintf(name, 256, "%s", base); ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 3b163d9a38e75..15cea72513536 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -2768,7 +2768,6 @@ int ggml_metal_op_im2col(ggml_metal_op_t ctx, int idx) { const uint64_t ofs0 = op->src[1]->nb[is_2D ? 3 : 2] / 4; const uint64_t ofs1 = op->src[1]->nb[is_2D ? 2 : 1] / 4; - ggml_metal_kargs_im2col args = { /*.ofs0 =*/ ofs0, /*.ofs1 =*/ ofs1, @@ -2789,15 +2788,16 @@ int ggml_metal_op_im2col(ggml_metal_op_t ctx, int idx) { ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_im2col(lib, op); - const uint64_t n_threads = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), N); - const int64_t quotient = N / n_threads + (N % n_threads > 0 ? 1 : 0); + GGML_ASSERT(KH*KW <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + + const uint64_t ntptg0 = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)/(KH*KW), N); ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 1); ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); - ggml_metal_encoder_dispatch_threadgroups(enc, quotient * CHW, OH, OW, n_threads, 1, 1); + ggml_metal_encoder_dispatch_threadgroups(enc, IC, OH, OW, ntptg0, KH, KW); return 1; } diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 2ba4cb50b9ec2..339cbf91fbe57 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -3987,8 +3987,72 @@ template [[host_name("kernel_rope_multi_f16")]] kernel kernel_rope_multi_t kerne template [[host_name("kernel_rope_vision_f32")]] kernel kernel_rope_vision_t kernel_rope_vision; template [[host_name("kernel_rope_vision_f16")]] kernel kernel_rope_vision_t kernel_rope_vision; +typedef void (im2col_t)( + constant ggml_metal_kargs_im2col & args, + device const float * x, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tgpg[[threadgroups_per_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]); + +template +kernel void kernel_im2col( + constant ggml_metal_kargs_im2col & args, + device const float * x, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tgpg[[threadgroups_per_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { +// const int64_t IC = tgpg[0]; + const int64_t OH = tgpg[1]; + const int64_t OW = tgpg[2]; + + const int64_t KH = ntg[1]; + const int64_t KW = ntg[2]; + + int64_t in = tpitg[0]; + const int64_t ikh = tpitg[1]; + const int64_t ikw = tpitg[2]; + + const int64_t iic = tgpig[0]; + const int64_t ioh = tgpig[1]; + const int64_t iow = tgpig[2]; + + const int64_t iiw = iow*args.s0 + ikw*args.d0 - args.p0; + const int64_t iih = ioh*args.s1 + ikh*args.d1 - args.p1; + + int64_t offset_dst = (in*OH*OW + ioh*OW + iow)*args.CHW + (iic*(KH*KW) + ikh*KW + ikw); + + device T * pdst = (device T *) (dst); + + if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) { + while (in < args.N) { + pdst[offset_dst] = 0.0f; + offset_dst += ntg[0]*args.CHW*OH*OW; + + in += ntg[0]; + } + } else { + int64_t offset_src = in*args.ofs0 + iic*args.ofs1 + iih*args.IW + iiw; + + while (in < args.N) { + pdst[offset_dst] = x[offset_src]; + + offset_dst += ntg[0]*args.CHW*OH*OW; + offset_src += ntg[0]*args.ofs0; + + in += ntg[0]; + } + } +} + +template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col; +template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col; + // TODO: obolete -- remove -//typedef void (im2col_t)( +//typedef void (im2col_ext_t)( // constant ggml_metal_kargs_im2col & args, // device const float * x, // device char * dst, @@ -3998,100 +4062,48 @@ template [[host_name("kernel_rope_vision_f16")]] kernel kernel_rope_vision_t ker // uint3 ntg[[threads_per_threadgroup]]); // //template -//kernel void kernel_im2col( +//kernel void kernel_im2col_ext( // constant ggml_metal_kargs_im2col & args, // device const float * x, // device char * dst, // uint3 tgpig[[threadgroup_position_in_grid]], -// uint3 tgpg[[threadgroups_per_grid]], +// uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW // uint3 tpitg[[thread_position_in_threadgroup]], -// uint3 ntg[[threads_per_threadgroup]]) { -//// const int64_t IC = tgpg[0]; -// const int64_t OH = tgpg[1]; -// const int64_t OW = tgpg[2]; +// uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1] +// const int64_t KHW = (int64_t)args.KHW; // -//// const int64_t N = ntg[0]; -// const int64_t KH = ntg[1]; -// const int64_t KW = ntg[2]; +// const int64_t d = tgpig[0] / args.CHW; +// const int64_t chw = tgpig[0] % args.CHW; +// const int64_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1) +// const int64_t HW = tgpig[0] % KHW; // -// const int64_t in = tpitg[0]; -// const int64_t ikh = tpitg[1]; -// const int64_t ikw = tpitg[2]; +// const int64_t tpitg_0 = (d * ntg[0]) + tpitg[0]; +// if (tpitg_0 >= args.N) { +// return; +// } // -// const int64_t iic = tgpig[0]; -// const int64_t ioh = tgpig[1]; -// const int64_t iow = tgpig[2]; +// const int64_t tpitg_1 = HW / args.KW; +// const int64_t tpitg_2 = HW % args.KW; // -// const int64_t iiw = iow*args.s0 + ikw*args.d0 - args.p0; -// const int64_t iih = ioh*args.s1 + ikh*args.d1 - args.p1; +// const int64_t iiw = tgpig[2] * args.s0 + tpitg_2 * args.d0 - args.p0; +// const int64_t iih = tgpig[1] * args.s1 + tpitg_1 * args.d1 - args.p1; // -// const int64_t offset_dst = (in*OH*OW + ioh*OW + iow)*args.CHW + (iic*(KH*KW) + ikh*KW + ikw); +// const int64_t offset_dst = +// (tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * args.CHW + +// (tgpig_0 * KHW + tpitg_1 * args.KW + tpitg_2); // // device T * pdst = (device T *) (dst); // // if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) { // pdst[offset_dst] = 0.0f; // } else { -// const int64_t offset_src = in*args.ofs0 + iic*args.ofs1 + iih*args.IW + iiw; -// pdst[offset_dst] = x[offset_src]; +// const int64_t offset_src = tpitg_0 * args.ofs0 + tgpig_0 * args.ofs1; +// pdst[offset_dst] = x[offset_src + iih * args.IW + iiw]; // } //} // -//template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col; -//template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col; - -typedef void (im2col_ext_t)( - constant ggml_metal_kargs_im2col & args, - device const float * x, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tgpg[[threadgroups_per_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]); - -template -kernel void kernel_im2col_ext( - constant ggml_metal_kargs_im2col & args, - device const float * x, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1] - const int64_t KHW = (int64_t)args.KHW; - - const int64_t d = tgpig[0] / args.CHW; - const int64_t chw = tgpig[0] % args.CHW; - const int64_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1) - const int64_t HW = tgpig[0] % KHW; - - const int64_t tpitg_0 = (d * ntg[0]) + tpitg[0]; - if (tpitg_0 >= args.N) { - return; - } - - const int64_t tpitg_1 = HW / args.KW; - const int64_t tpitg_2 = HW % args.KW; - - const int64_t iiw = tgpig[2] * args.s0 + tpitg_2 * args.d0 - args.p0; - const int64_t iih = tgpig[1] * args.s1 + tpitg_1 * args.d1 - args.p1; - - const int64_t offset_dst = - (tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * args.CHW + - (tgpig_0 * KHW + tpitg_1 * args.KW + tpitg_2); - - device T * pdst = (device T *) (dst); - - if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) { - pdst[offset_dst] = 0.0f; - } else { - const int64_t offset_src = tpitg_0 * args.ofs0 + tgpig_0 * args.ofs1; - pdst[offset_dst] = x[offset_src + iih * args.IW + iiw]; - } -} - -template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext; -template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext; +//template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext; +//template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext; typedef void (conv_transpose_1d_t)( constant ggml_metal_kargs_conv_transpose_1d & args,