@@ -394,6 +394,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
394394 {
395395 n_fuse = ggml_metal_op_conv_transpose_2d (ctx, idx);
396396 } break ;
397+ case GGML_OP_CONV_3D:
398+ {
399+ n_fuse = ggml_metal_op_conv_3d (ctx, idx);
400+ } break ;
397401 case GGML_OP_UPSCALE:
398402 {
399403 n_fuse = ggml_metal_op_upscale (ctx, idx);
@@ -3697,6 +3701,77 @@ int ggml_metal_op_conv_2d(ggml_metal_op_t ctx, int idx) {
36973701 return 1 ;
36983702}
36993703
3704+ int ggml_metal_op_conv_3d (ggml_metal_op_t ctx, int idx) {
3705+ ggml_tensor * op = ctx->node (idx);
3706+
3707+ ggml_metal_library_t lib = ctx->lib ;
3708+ ggml_metal_encoder_t enc = ctx->enc ;
3709+
3710+ // 1. Extract standard dimensions and byte strides
3711+ GGML_TENSOR_LOCALS (uint64_t , nb0, op->src [0 ], nb);
3712+ GGML_TENSOR_LOCALS (uint64_t , nb1, op->src [1 ], nb);
3713+ GGML_TENSOR_LOCALS (uint64_t , nb, op, nb);
3714+
3715+ // 2. Extract hyperparams from op_params
3716+ const int32_t s0 = ((const int32_t *)(op->op_params ))[0 ];
3717+ const int32_t s1 = ((const int32_t *)(op->op_params ))[1 ];
3718+ const int32_t s2 = ((const int32_t *)(op->op_params ))[2 ];
3719+ const int32_t p0 = ((const int32_t *)(op->op_params ))[3 ];
3720+ const int32_t p1 = ((const int32_t *)(op->op_params ))[4 ];
3721+ const int32_t p2 = ((const int32_t *)(op->op_params ))[5 ];
3722+ const int32_t d0 = ((const int32_t *)(op->op_params ))[6 ];
3723+ const int32_t d1 = ((const int32_t *)(op->op_params ))[7 ];
3724+ const int32_t d2 = ((const int32_t *)(op->op_params ))[8 ];
3725+ const int32_t IC = ((const int32_t *)(op->op_params ))[9 ];
3726+ const int32_t N = ((const int32_t *)(op->op_params ))[10 ];
3727+ const int32_t OC = ((const int32_t *)(op->op_params ))[11 ];
3728+
3729+ // 3. Build the parameter struct using the macro-generated variables
3730+ ggml_metal_kargs_conv_3d args = {
3731+ /* .IW =*/ (int32_t )op->src [1 ]->ne [0 ],
3732+ /* .IH =*/ (int32_t )op->src [1 ]->ne [1 ],
3733+ /* .ID =*/ (int32_t )op->src [1 ]->ne [2 ],
3734+ /* .OW =*/ (int32_t )op->ne [0 ],
3735+ /* .OH =*/ (int32_t )op->ne [1 ],
3736+ /* .OD =*/ (int32_t )op->ne [2 ],
3737+ /* .KW =*/ (int32_t )op->src [0 ]->ne [0 ],
3738+ /* .KH =*/ (int32_t )op->src [0 ]->ne [1 ],
3739+ /* .KD =*/ (int32_t )op->src [0 ]->ne [2 ],
3740+ s0, s1, s2,
3741+ p0, p1, p2,
3742+ d0, d1, d2,
3743+ IC, N, OC,
3744+ nb00, nb01, nb02, nb03, // Weight strides
3745+ nb10, nb11, nb12, nb13, // Input strides
3746+ nb0, nb1, nb2, nb3 // Output strides
3747+ };
3748+
3749+ // 4. Fetch the JIT pipeline
3750+ auto pipeline = ggml_metal_library_get_pipeline_conv_3d (lib, op);
3751+
3752+ // 5. Grid mapping
3753+ int nth0 = 32 ; // Standard SIMD width for Apple Silicon
3754+ int nth1 = 1 ;
3755+ int nth2 = 1 ;
3756+
3757+ int64_t spatial_volume = args.OW * args.OH * args.OD ;
3758+
3759+ int ntg0 = (spatial_volume + nth0 - 1 ) / nth0;
3760+ int ntg1 = args.OC ;
3761+ int ntg2 = args.N ;
3762+
3763+ // 6. Bind and Dispatch via the ggml C wrapper
3764+ ggml_metal_encoder_set_pipeline (enc, pipeline);
3765+ ggml_metal_encoder_set_bytes (enc, &args, sizeof (args), 0 );
3766+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [0 ]), 1 );
3767+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [1 ]), 2 );
3768+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op), 3 );
3769+
3770+ ggml_metal_encoder_dispatch_threadgroups (enc, ntg0, ntg1, ntg2, nth0, nth1, nth2);
3771+
3772+ return 1 ;
3773+ }
3774+
37003775int ggml_metal_op_conv_transpose_1d (ggml_metal_op_t ctx, int idx) {
37013776 ggml_tensor * op = ctx->node (idx);
37023777
0 commit comments