Skip to content

Commit 177c758

Browse files
metal: add CONV_3D (#19927)
* Apply suggestions from code review

  Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* metal: add conv_3d backend

  Rebased with master and resolved conflicts.

* Resolved issues related to changes in variable names

* kernel void kernel_upscale_bilinear_f32 was missing in my branch, added back, should pass all tests now

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
1 parent 7a0b6a6 commit 177c758

File tree

7 files changed

+232
-0
lines changed

7 files changed

+232
-0
lines changed

ggml/src/ggml-metal/ggml-metal-device.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1748,6 +1748,28 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_2d(ggml_met
17481748
return res;
17491749
}
17501750

1751+
// Returns the Metal pipeline used to run a GGML_OP_CONV_3D node.
//
// The kernel is addressed by name, "kernel_conv_3d_<src0 type>_<src1 type>", where the two
// type names come from ggml_type_name(). The library is queried for that name first; only
// if it has no pipeline yet is one compiled.
//
// Preconditions (asserted): src[0] is contiguous and F16 or F32, src[1] is F32 and the
// result is F32.
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_3d(ggml_metal_library_t lib, const ggml_tensor * op) {
    assert(op->op == GGML_OP_CONV_3D);

    // the assert expressions are part of the failure message, keep them as they are
    GGML_ASSERT(ggml_is_contiguous(op->src[0]));
    GGML_ASSERT(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32);
    GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
    GGML_ASSERT(op->type == GGML_TYPE_F32);

    char base[256];
    char name[256];

    const char * tname0 = ggml_type_name(op->src[0]->type);
    const char * tname1 = ggml_type_name(op->src[1]->type);

    // there are no per-op specializations, so the pipeline name equals the base kernel name
    snprintf(base, sizeof(base), "kernel_conv_3d_%s_%s", tname0, tname1);
    snprintf(name, sizeof(name), "%s", base);

    ggml_metal_pipeline_with_params found = ggml_metal_library_get_pipeline(lib, name);
    if (found.pipeline) {
        return found;
    }

    return ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
}
1772+
17511773
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_upscale(ggml_metal_library_t lib, const ggml_tensor * op) {
17521774
assert(op->op == GGML_OP_UPSCALE);
17531775

ggml/src/ggml-metal/ggml-metal-device.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_im2col
148148
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);
149149
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_2d (ggml_metal_library_t lib, const struct ggml_tensor * op);
150150
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_2d (ggml_metal_library_t lib, const struct ggml_tensor * op);
151+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_3d (ggml_metal_library_t lib, const struct ggml_tensor * op);
151152
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_upscale (ggml_metal_library_t lib, const struct ggml_tensor * op);
152153
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad (ggml_metal_library_t lib, const struct ggml_tensor * op);
153154
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);

ggml/src/ggml-metal/ggml-metal-device.m

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1077,6 +1077,11 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
10771077
(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32) &&
10781078
op->src[1]->type == GGML_TYPE_F32 &&
10791079
op->type == GGML_TYPE_F32;
1080+
case GGML_OP_CONV_3D:
    // src[0] holds the weights (F16 or F32), src[1] the input (F32).
    // The result type is checked too: ggml_metal_library_get_pipeline_conv_3d() asserts
    // op->type == GGML_TYPE_F32, so an op that is reported as supported here must never
    // be able to trip that assert later on.
    return ggml_is_contiguous(op->src[0]) &&
           ggml_is_contiguous(op->src[1]) &&
           (op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32) &&
           op->src[1]->type == GGML_TYPE_F32 &&
           op->type == GGML_TYPE_F32;
10801085
case GGML_OP_SUM:
10811086
return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
10821087
case GGML_OP_TRI:

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -643,6 +643,42 @@ typedef struct {
643643
int32_t KHW; // KH * KW, pre-computed on CPU to save GPU resources
644644
} ggml_metal_kargs_im2col;
645645

646+
// Arguments of kernel_conv_3d (GGML_OP_CONV_3D).
//
// The host fills this struct positionally, so the order of the fields must not change.
typedef struct {
    // extent of the input volume (src1): ne[0], ne[1], ne[2]
    int32_t  IW;
    int32_t  IH;
    int32_t  ID;

    // extent of the output volume (dst): ne[0], ne[1], ne[2]
    int32_t  OW;
    int32_t  OH;
    int32_t  OD;

    // extent of the convolution kernel (src0): ne[0], ne[1], ne[2]
    int32_t  KW;
    int32_t  KH;
    int32_t  KD;

    // stride along width, height, depth
    int32_t  s0;
    int32_t  s1;
    int32_t  s2;

    // padding along width, height, depth
    int32_t  p0;
    int32_t  p1;
    int32_t  p2;

    // dilation along width, height, depth
    int32_t  d0;
    int32_t  d1;
    int32_t  d2;

    int32_t  IC; // input channels
    int32_t  N;  // batch size
    int32_t  OC; // output channels

    // byte strides of src0 (weights)
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;

    // byte strides of src1 (input)
    uint64_t nb10;
    uint64_t nb11;
    uint64_t nb12;
    uint64_t nb13;

    // byte strides of dst (output)
    uint64_t nb0;
    uint64_t nb1;
    uint64_t nb2;
    uint64_t nb3;
} ggml_metal_kargs_conv_3d;
681+
646682
typedef struct{
647683
int32_t ne00;
648684
uint64_t nb01;

ggml/src/ggml-metal/ggml-metal-ops.cpp

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
394394
{
395395
n_fuse = ggml_metal_op_conv_transpose_2d(ctx, idx);
396396
} break;
397+
case GGML_OP_CONV_3D:
398+
{
399+
n_fuse = ggml_metal_op_conv_3d(ctx, idx);
400+
} break;
397401
case GGML_OP_UPSCALE:
398402
{
399403
n_fuse = ggml_metal_op_upscale(ctx, idx);
@@ -3697,6 +3701,77 @@ int ggml_metal_op_conv_2d(ggml_metal_op_t ctx, int idx) {
36973701
return 1;
36983702
}
36993703

3704+
// Encodes a GGML_OP_CONV_3D node into the current Metal command encoder.
//
// Launch layout (has to stay in sync with kernel_conv_3d):
//   x : ceil(OW*OH*OD / 32) threadgroups of 32 threads, one thread per output position
//   y : OC threadgroups, one per output channel
//   z : N  threadgroups, one per batch entry
//
// Buffer bindings: 0 = kernel args, 1 = src[0] (weights), 2 = src[1] (input), 3 = dst.
//
// Returns 1; the caller stores the result in n_fuse.
int ggml_metal_op_conv_3d(ggml_metal_op_t ctx, int idx) {
    ggml_tensor * op = ctx->node(idx);

    ggml_metal_library_t lib = ctx->lib;
    ggml_metal_encoder_t enc = ctx->enc;

    // byte strides: nb00..nb03 (weights), nb10..nb13 (input), nb0..nb3 (output)
    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
    GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);

    // hyperparameters stored with the op: stride, padding, dilation, then IC / N / OC
    const int32_t * params = (const int32_t *) op->op_params;

    const int32_t s0 = params[0];
    const int32_t s1 = params[1];
    const int32_t s2 = params[2];
    const int32_t p0 = params[3];
    const int32_t p1 = params[4];
    const int32_t p2 = params[5];
    const int32_t d0 = params[6];
    const int32_t d1 = params[7];
    const int32_t d2 = params[8];
    const int32_t IC = params[9];
    const int32_t N  = params[10];
    const int32_t OC = params[11];

    ggml_metal_kargs_conv_3d args = {
        /*.IW   =*/ (int32_t) op->src[1]->ne[0],
        /*.IH   =*/ (int32_t) op->src[1]->ne[1],
        /*.ID   =*/ (int32_t) op->src[1]->ne[2],
        /*.OW   =*/ (int32_t) op->ne[0],
        /*.OH   =*/ (int32_t) op->ne[1],
        /*.OD   =*/ (int32_t) op->ne[2],
        /*.KW   =*/ (int32_t) op->src[0]->ne[0],
        /*.KH   =*/ (int32_t) op->src[0]->ne[1],
        /*.KD   =*/ (int32_t) op->src[0]->ne[2],
        /*.s0   =*/ s0,
        /*.s1   =*/ s1,
        /*.s2   =*/ s2,
        /*.p0   =*/ p0,
        /*.p1   =*/ p1,
        /*.p2   =*/ p2,
        /*.d0   =*/ d0,
        /*.d1   =*/ d1,
        /*.d2   =*/ d2,
        /*.IC   =*/ IC,
        /*.N    =*/ N,
        /*.OC   =*/ OC,
        /*.nb00 =*/ nb00,
        /*.nb01 =*/ nb01,
        /*.nb02 =*/ nb02,
        /*.nb03 =*/ nb03,
        /*.nb10 =*/ nb10,
        /*.nb11 =*/ nb11,
        /*.nb12 =*/ nb12,
        /*.nb13 =*/ nb13,
        /*.nb0  =*/ nb0,
        /*.nb1  =*/ nb1,
        /*.nb2  =*/ nb2,
        /*.nb3  =*/ nb3,
    };

    auto pipeline = ggml_metal_library_get_pipeline_conv_3d(lib, op);

    // threads per threadgroup - kernel_conv_3d hard-codes the same 32 when it rebuilds the
    // flat output index, so the two values must be changed together
    const int nth0 = 32;
    const int nth1 = 1;
    const int nth2 = 1;

    // widen BEFORE multiplying: the product of three int32 extents does not have to fit
    // into int32, and signed overflow is undefined behavior
    const int64_t spatial_volume = (int64_t) args.OW * (int64_t) args.OH * (int64_t) args.OD;

    const int ntg0 = (int) ((spatial_volume + nth0 - 1) / nth0);
    const int ntg1 = args.OC;
    const int ntg2 = args.N;

    ggml_metal_encoder_set_pipeline(enc, pipeline);
    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         3);

    ggml_metal_encoder_dispatch_threadgroups(enc, ntg0, ntg1, ntg2, nth0, nth1, nth2);

    return 1;
}
3774+
37003775
int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) {
37013776
ggml_tensor * op = ctx->node(idx);
37023777

ggml/src/ggml-metal/ggml-metal-ops.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ int ggml_metal_op_norm (ggml_metal_op_t ctx, int idx);
7575
int ggml_metal_op_rope (ggml_metal_op_t ctx, int idx);
7676
int ggml_metal_op_im2col (ggml_metal_op_t ctx, int idx);
7777
int ggml_metal_op_conv_2d (ggml_metal_op_t ctx, int idx);
78+
int ggml_metal_op_conv_3d (ggml_metal_op_t ctx, int idx);
7879
int ggml_metal_op_conv_transpose_1d (ggml_metal_op_t ctx, int idx);
7980
int ggml_metal_op_conv_transpose_2d (ggml_metal_op_t ctx, int idx);
8081
int ggml_metal_op_upscale (ggml_metal_op_t ctx, int idx);

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4883,6 +4883,98 @@ kernel void kernel_upscale_bilinear_f32(
48834883
}
48844884
}
48854885

4886+
// Direct 3D convolution: every thread computes exactly one output element.
//
// Launch layout (set up by ggml_metal_op_conv_3d on the host):
//   tgpig.x * 32 + tpitg.x -> flattened (od, oh, ow) position inside the output volume
//   tgpig.y                -> output channel
//   tgpig.z                -> batch index
//
// T is the element type of the weights (src0). The input (src1) is read as float and the
// output (dst) is written as float. All addressing goes through the byte strides in args.
// Input positions that fall outside the input volume are skipped, i.e. they contribute 0.
template <typename T>
kernel void kernel_conv_3d(
        constant ggml_metal_kargs_conv_3d & args,
        device const char * src0, // Weights [IC * OC, KD, KH, KW]
        device const char * src1, // Inputs  [IC * N,  ID, IH, IW]
        device       char * dst,  // Outputs [OC * N,  OD, OH, OW]
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]]) {

    // threads per threadgroup along x - must match nth0 in ggml_metal_op_conv_3d
    const int64_t nth = 32;

    // 64-bit copies of the output extent: multiplying the int32 fields directly would be
    // evaluated in 32 bits and can overflow for large volumes
    const int64_t OW = args.OW;
    const int64_t OH = args.OH;
    const int64_t OD = args.OD;

    // 1. Un-flatten the spatial dimension from grid x (computed in 64 bits as well)
    const int64_t spatial_idx = (int64_t) tgpig.x * nth + (int64_t) tpitg.x;

    if (spatial_idx >= OW * OH * OD) {
        return; // padding thread of the last threadgroup
    }

    const int64_t od = spatial_idx / (OW * OH);
    const int64_t oh = (spatial_idx / OW) % OH;
    const int64_t ow = spatial_idx % OW;

    // 2. Grid y selects the output channel, grid z the batch entry
    const int64_t oc        = tgpig.y;
    const int64_t batch_idx = tgpig.z;

    // 3. Input coordinates that line up with kernel tap (0, 0, 0); negative inside the padding
    const int64_t i_w_base = ow * args.s0 - args.p0;
    const int64_t i_h_base = oh * args.s1 - args.p1;
    const int64_t i_d_base = od * args.s2 - args.p2;

    float sum = 0.0f;

    // 4. Gather loop: input channel -> depth -> height -> width
    for (int64_t ic = 0; ic < args.IC; ++ic) {
        // batch and channel share the 4th dimension of the tensors
        const int64_t src_cn_idx = batch_idx * args.IC + ic;
        const int64_t w_cn_idx   = oc * args.IC + ic;

        for (int64_t kz = 0; kz < args.KD; ++kz) {
            const int64_t id = i_d_base + kz * args.d2;
            if (id < 0 || id >= args.ID) {
                continue; // tap lies in the padding
            }

            for (int64_t ky = 0; ky < args.KH; ++ky) {
                const int64_t ih = i_h_base + ky * args.d1;
                if (ih < 0 || ih >= args.IH) {
                    continue;
                }

                for (int64_t kx = 0; kx < args.KW; ++kx) {
                    const int64_t iw = i_w_base + kx * args.d0;
                    if (iw < 0 || iw >= args.IW) {
                        continue;
                    }

                    // byte offsets of the current weight and input element
                    const int64_t w_idx = kx*args.nb00 + ky*args.nb01 + kz*args.nb02 + w_cn_idx*args.nb03;
                    const int64_t i_idx = iw*args.nb10 + ih*args.nb11 + id*args.nb12 + src_cn_idx*args.nb13;

                    // weights are converted to float (no-op for T = float)
                    const float w_val = (float) *(device const T     *) (src0 + w_idx);
                    const float i_val =         *(device const float *) (src1 + i_idx);

                    sum += w_val * i_val;
                }
            }
        }
    }

    // 5. Store the result
    const int64_t dst_cn_idx = batch_idx * args.OC + oc;
    const int64_t d_idx      = ow*args.nb0 + oh*args.nb1 + od*args.nb2 + dst_cn_idx*args.nb3;

    *(device float *) (dst + d_idx) = sum;
}
4956+
4957+
// Explicit instantiations: the host looks the kernels up by these names
// ("kernel_conv_3d_<weight type>_<input type>").
typedef decltype(kernel_conv_3d<float>) kernel_conv_3d_t;

template [[host_name("kernel_conv_3d_f32_f32")]] kernel kernel_conv_3d_t kernel_conv_3d<float>; // f32 weights
template [[host_name("kernel_conv_3d_f16_f32")]] kernel kernel_conv_3d_t kernel_conv_3d<half>;  // f16 weights
4976+
4977+
48864978
static inline float bicubic_weight1(float x) {
48874979
const float a = -0.75f;
48884980
return ((a + 2) * x - (a + 3)) * x * x + 1;

0 commit comments

Comments
 (0)