@@ -574,6 +574,8 @@ struct vk_device_struct {
574574 vk_pipeline pipeline_opt_step_sgd_f32;
575575 vk_pipeline pipeline_conv2d_f32[CONV_SHAPE_COUNT];
576576 vk_pipeline pipeline_conv2d_f16_f32[CONV_SHAPE_COUNT];
577+ vk_pipeline pipeline_conv_transpose_2d_f32[CONV_SHAPE_COUNT];
578+ vk_pipeline pipeline_conv_transpose_2d_f16_f32[CONV_SHAPE_COUNT];
577579 vk_pipeline pipeline_conv2d_dw_whcn_f32, pipeline_conv2d_dw_whcn_f16_f32;
578580 vk_pipeline pipeline_conv2d_dw_cwhn_f32, pipeline_conv2d_dw_cwhn_f16_f32;
579581
@@ -1117,6 +1119,56 @@ template <> void init_pushconst_fastdiv(vk_op_conv2d_push_constants &p) {
11171119 init_fastdiv_values(p.OW*p.OH, p.OWOHmp, p.OWOHL);
11181120}
11191121
1122+ struct vk_op_conv_transpose_2d_push_constants {
1123+ uint32_t Cout;
1124+ uint32_t Cin;
1125+ uint32_t N;
1126+
1127+ uint32_t KW;
1128+ uint32_t KH;
1129+ uint32_t W;
1130+ uint32_t H;
1131+ uint32_t OW;
1132+ uint32_t OH;
1133+
1134+ uint32_t s0;
1135+ uint32_t s1;
1136+ uint32_t p0;
1137+ uint32_t p1;
1138+ uint32_t d0;
1139+ uint32_t d1;
1140+
1141+ uint32_t nb01;
1142+ uint32_t nb02;
1143+ uint32_t nb03;
1144+
1145+ uint32_t nb11;
1146+ uint32_t nb12;
1147+ uint32_t nb13;
1148+
1149+ uint32_t nb1;
1150+ uint32_t nb2;
1151+ uint32_t nb3;
1152+
1153+ // init_fastdiv_values constants for dividing by KW, KW*KH, OW, OW*OH, s0, s1
1154+ uint32_t KWmp; uint32_t KWL;
1155+ uint32_t KWKHmp; uint32_t KWKHL;
1156+ uint32_t OWmp; uint32_t OWL;
1157+ uint32_t OWOHmp; uint32_t OWOHL;
1158+ uint32_t s0mp; uint32_t s0L;
1159+ uint32_t s1mp; uint32_t s1L;
1160+ };
1161+
1162+ template <> void init_pushconst_fastdiv(vk_op_conv_transpose_2d_push_constants &p) {
1163+ // Compute magic values to divide by KW, KW*KH, OW, OW*OH, s0, s1
1164+ init_fastdiv_values(p.KW, p.KWmp, p.KWL);
1165+ init_fastdiv_values(p.KW*p.KH, p.KWKHmp, p.KWKHL);
1166+ init_fastdiv_values(p.OW, p.OWmp, p.OWL);
1167+ init_fastdiv_values(p.OW*p.OH, p.OWOHmp, p.OWOHL);
1168+ init_fastdiv_values(p.s0, p.s0mp, p.s0L);
1169+ init_fastdiv_values(p.s1, p.s1mp, p.s1L);
1170+ }
1171+
11201172struct vk_op_conv2d_dw_push_constants {
11211173 uint32_t ne;
11221174 uint32_t batches;
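For context on the paired `*mp`/`*L` fields filled in by the init_pushconst_fastdiv specialization above: the shaders avoid integer division in hot loops by dividing with a precomputed magic multiplier and shift. A minimal host-side sketch of the idea follows; the helper name mirrors init_fastdiv_values, but the body is an illustrative reconstruction, not the exact upstream implementation.

    #include <cassert>
    #include <cstdint>

    // For a fixed divisor d, choose L = ceil(log2(d)) and a 32-bit multiplier mp
    // such that n / d == (mulhi(n, mp) + n) >> L for the value range used here.
    static void init_fastdiv_values_sketch(uint32_t d, uint32_t &mp, uint32_t &L) {
        L = 0;
        while (L < 32 && (uint32_t{1} << L) < d) {
            L++;
        }
        mp = uint32_t(((uint64_t{1} << 32) * ((uint64_t{1} << L) - d)) / d + 1);
    }

    // Shader-side counterpart: umulExtended() would supply the high 32 bits of n * mp.
    static uint32_t fastdiv_sketch(uint32_t n, uint32_t mp, uint32_t L) {
        uint32_t hi = uint32_t((uint64_t(n) * mp) >> 32);
        return uint32_t((uint64_t(hi) + n) >> L); // 64-bit add to sidestep overflow in the sketch
    }

    int main() {
        uint32_t mp, L;
        init_fastdiv_values_sketch(7, mp, L);
        for (uint32_t n = 0; n < 1000000; ++n) {
            assert(fastdiv_sketch(n, mp, L) == n / 7);
        }
        return 0;
    }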
@@ -1322,7 +1374,7 @@ class vk_perf_logger {
13221374 flops[name].push_back(m * n * (k + (k - 1)) * batch);
13231375 return;
13241376 }
1325- if (node->op == GGML_OP_CONV_2D) {
1377+ if (node->op == GGML_OP_CONV_2D || node->op == GGML_OP_CONV_TRANSPOSE_2D) {
13261378 std::string name = ggml_op_name(node->op);
13271379 ggml_tensor * knl = node->src[0];
13281380 uint64_t OW = node->ne[0];
@@ -1331,7 +1383,7 @@ class vk_perf_logger {
13311383 uint64_t Cout = node->ne[2];
13321384 uint64_t KW = knl->ne[0];
13331385 uint64_t KH = knl->ne[1];
1334- uint64_t Cin = knl->ne[2];
1386+ uint64_t Cin = node->src[1]->ne[2];
13351387 // KxCRS @ CRSxNPQ = KxNPQ -> M=K, K=CRS, N=NPQ
13361388 uint64_t size_M = Cout;
13371389 uint64_t size_K = Cin * KW * KH;
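The Cin change matters for the new op: the CONV_TRANSPOSE_2D kernel is laid out as [KW, KH, Cout, Cin], so knl->ne[2] is Cout rather than Cin, whereas the input's ne[2] is Cin for both ops. Reading Cin from node->src[1] keeps the implicit-GEMM FLOPs estimate (M = Cout, K = Cin * KW * KH, N = N * OW * OH) correct for both convolution variants.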
@@ -3492,7 +3544,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
34923544
34933545 ggml_vk_create_pipeline(device, device->pipeline_opt_step_sgd_f32, "opt_step_sgd_f32", opt_step_sgd_f32_len, opt_step_sgd_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
34943546
3495- // conv2d
3547+ // conv2d, conv_transpose_2d
34963548 for (uint32_t s = 0; s < CONV_SHAPE_COUNT; ++s) {
34973549 uint32_t conv2d_WG_SIZE = 256;
34983550 uint32_t conv2d_BS_K = 128;
@@ -3567,31 +3619,30 @@ static void ggml_vk_load_shaders(vk_device& device) {
35673619 std::array<uint32_t, 3> wg_denoms = { conv2d_BS_K, conv2d_BS_NPQ, 1 };
35683620 std::vector<uint32_t> spec_constants = { conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives, conv2d_SHMEM_PAD };
35693621
3622+ #define CREATE_CONV(name, type_suffix, spv_suffix) \
3623+ ggml_vk_create_pipeline( \
3624+ device, device->pipeline_##name##type_suffix[s], #name #type_suffix, \
3625+ name##type_suffix##spv_suffix##_len, name##type_suffix##spv_suffix##_data, "main", 3, \
3626+ sizeof(vk_op_##name##_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);
3627+ #define CREATE_CONVS(spv_suffix) \
3628+ CREATE_CONV(conv2d, _f32, spv_suffix) \
3629+ CREATE_CONV(conv2d, _f16_f32, spv_suffix) \
3630+ if (device->properties.limits.maxPushConstantsSize >= sizeof(vk_op_conv_transpose_2d_push_constants)) { \
3631+ CREATE_CONV(conv_transpose_2d, _f32, spv_suffix) \
3632+ CREATE_CONV(conv_transpose_2d, _f16_f32, spv_suffix) \
3633+ }
35703634#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
35713635 if (device->coopmat2) {
3572- ggml_vk_create_pipeline(
3573- device, device->pipeline_conv2d_f32[s], "conv2d_f32", conv2d_f32_cm2_len, conv2d_f32_cm2_data, "main", 3,
3574- sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);
3575- ggml_vk_create_pipeline(
3576- device, device->pipeline_conv2d_f16_f32[s], "conv2d_f16_f32", conv2d_f16_f32_cm2_len, conv2d_f16_f32_cm2_data, "main", 3,
3577- sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);
3636+ CREATE_CONVS(_cm2)
35783637 } else
35793638#endif
35803639 if (conv2d_UNROLL) {
3581- ggml_vk_create_pipeline(
3582- device, device->pipeline_conv2d_f32[s], "conv2d_f32", conv2d_f32_unroll_len, conv2d_f32_unroll_data, "main", 3,
3583- sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);
3584- ggml_vk_create_pipeline(
3585- device, device->pipeline_conv2d_f16_f32[s], "conv2d_f16_f32", conv2d_f16_f32_unroll_len, conv2d_f16_f32_unroll_data, "main", 3,
3586- sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);
3640+ CREATE_CONVS(_unroll)
35873641 } else {
3588- ggml_vk_create_pipeline(
3589- device, device->pipeline_conv2d_f32[s], "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3,
3590- sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);
3591- ggml_vk_create_pipeline(
3592- device, device->pipeline_conv2d_f16_f32[s], "conv2d_f16_f32", conv2d_f16_f32_len, conv2d_f16_f32_data, "main", 3,
3593- sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);
3642+ CREATE_CONVS( )
35943643 }
3644+ #undef CREATE_CONV
3645+ #undef CREATE_CONVS
35953646 }
35963647
35973648 ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
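To make the refactor easier to review: with the current shape index s, one CREATE_CONV invocation emitted by CREATE_CONVS(_cm2) should expand to roughly the following (reconstructed by hand from the macro, so treat it as illustrative):

    ggml_vk_create_pipeline(
        device, device->pipeline_conv_transpose_2d_f32[s], "conv_transpose_2d_f32",
        conv_transpose_2d_f32_cm2_len, conv_transpose_2d_f32_cm2_data, "main", 3,
        sizeof(vk_op_conv_transpose_2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);

The maxPushConstantsSize guard inside CREATE_CONVS skips creating the conv_transpose_2d pipelines on devices whose push-constant limit cannot fit the larger block; the same condition is mirrored in supports_op further down.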
@@ -7548,6 +7599,33 @@ static std::array<uint32_t, 3> ggml_vk_get_conv_elements(const ggml_tensor *dst)
75487599 return elements;
75497600}
75507601
7602+ static std::array<uint32_t, 3> ggml_vk_get_conv_transpose_2d_elements(const ggml_tensor *dst) {
7603+ const ggml_tensor *src0 = dst->src[0];
7604+ const ggml_tensor *src1 = dst->src[1];
7605+
7606+ // src0 - kernel: [KW, KH, Cout, Cin]
7607+ // src1 - input: [W, H, Cin, N]
7608+ // dst - result: [OW, OH, Cout, N]
7609+
7610+ auto calc_conv_output_size = [](int64_t ins, int64_t ks, int s, int p, int d) -> int64_t {
7611+ return (ins - 1) * s - 2 * p + (ks - 1) * d + 1;
7612+ };
7613+ // parallelize in {Cout/BS_K, NPQ/BS_NPQ, 1} workgroups
7614+ int64_t W = src1->ne[0];
7615+ int64_t H = src1->ne[1];
7616+ int64_t KW = src0->ne[0];
7617+ int64_t KH = src0->ne[1];
7618+ int64_t Cout = src0->ne[2];
7619+ int64_t N = src1->ne[3];
7620+ int64_t OH = calc_conv_output_size(H, KH, dst->op_params[0], 0, 1);
7621+ int64_t OW = calc_conv_output_size(W, KW, dst->op_params[0], 0, 1);
7622+ int64_t NPQ = N * OW * OH;
7623+
7624+ // Tile output matrix to (K/NB_K, NPQ/NB_NPQ, 1) workgroups
7625+ std::array<uint32_t, 3> elements = { static_cast<uint32_t>(Cout), static_cast<uint32_t>(NPQ), 1 };
7626+ return elements;
7627+ }
7628+
75517629static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * dst, ggml_op op) {
75527630 switch (op) {
75537631 case GGML_OP_GET_ROWS:
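With p = 0 and d = 1 (the only mode GGML_OP_CONV_TRANSPOSE_2D uses here), the output-size lambda in ggml_vk_get_conv_transpose_2d_elements above reduces to OW = (W - 1) * s + KW. For example, W = 32, KW = 3, s = 2 gives OW = 31 * 2 + 3 = 65; OH follows the same formula, since the op carries a single stride in op_params[0].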
@@ -7925,9 +8003,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
79258003 }
79268004 return nullptr;
79278005 case GGML_OP_CONV_2D:
8006+ case GGML_OP_CONV_TRANSPOSE_2D:
79288007 if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
79298008 ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
7930- auto elements = ggml_vk_get_conv_elements(dst);
8009+ std::array<uint32_t, 3> elements;
8010+ if (op == GGML_OP_CONV_2D) elements = ggml_vk_get_conv_elements(dst);
8011+ else if (op == GGML_OP_CONV_TRANSPOSE_2D) elements = ggml_vk_get_conv_transpose_2d_elements(dst);
79318012 vk_conv_shapes shape;
79328013
79338014 uint32_t tiles[CONV_SHAPE_COUNT];
@@ -7947,10 +8028,18 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
79478028 shape = CONV_SHAPE_64x32;
79488029 }
79498030
7950- if (src0->type == GGML_TYPE_F32) {
7951- return ctx->device->pipeline_conv2d_f32[shape];
7952- } else if (src0->type == GGML_TYPE_F16) {
7953- return ctx->device->pipeline_conv2d_f16_f32[shape];
8031+ if (op == GGML_OP_CONV_2D) {
8032+ if (src0->type == GGML_TYPE_F32) {
8033+ return ctx->device->pipeline_conv2d_f32[shape];
8034+ } else if (src0->type == GGML_TYPE_F16) {
8035+ return ctx->device->pipeline_conv2d_f16_f32[shape];
8036+ }
8037+ } else if (op == GGML_OP_CONV_TRANSPOSE_2D) {
8038+ if (src0->type == GGML_TYPE_F32) {
8039+ return ctx->device->pipeline_conv_transpose_2d_f32[shape];
8040+ } else if (src0->type == GGML_TYPE_F16) {
8041+ return ctx->device->pipeline_conv_transpose_2d_f16_f32[shape];
8042+ }
79548043 }
79558044 }
79568045 return nullptr;
@@ -8350,6 +8439,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
83508439 {
83518440 elements = ggml_vk_get_conv_elements(dst);
83528441 } break;
8442+ case GGML_OP_CONV_TRANSPOSE_2D:
8443+ {
8444+ elements = ggml_vk_get_conv_transpose_2d_elements(dst);
8445+ } break;
83538446 case GGML_OP_ADD:
83548447 case GGML_OP_SUB:
83558448 case GGML_OP_DIV:
@@ -9523,6 +9616,55 @@ static void ggml_vk_conv_2d(ggml_backend_vk_context * ctx, vk_context & subctx,
95239616 ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_2D, std::move(p), dryrun);
95249617}
95259618
9619+ static void ggml_vk_conv_transpose_2d(ggml_backend_vk_context * ctx, vk_context & subctx, const ggml_tensor * src0,
9620+ const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
9621+ GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
9622+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
9623+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
9624+
9625+ GGML_TENSOR_BINARY_OP_LOCALS
9626+
9627+ GGML_ASSERT(nb00 == sizeof(float) || nb00 == sizeof(ggml_fp16_t));
9628+ GGML_ASSERT(nb10 == sizeof(float));
9629+ GGML_ASSERT(nb0 == sizeof(float));
9630+
9631+ vk_op_conv_transpose_2d_push_constants p{};
9632+ p.Cout = static_cast<uint32_t>(ne02);
9633+ p.Cin = static_cast<uint32_t>(ne03);
9634+ p.N = static_cast<uint32_t>(ne13);
9635+
9636+ p.KW = static_cast<uint32_t>(ne00);
9637+ p.KH = static_cast<uint32_t>(ne01);
9638+ p.W = static_cast<uint32_t>(ne10);
9639+ p.H = static_cast<uint32_t>(ne11);
9640+ p.OW = static_cast<uint32_t>(ne0);
9641+ p.OH = static_cast<uint32_t>(ne1);
9642+
9643+ p.s0 = static_cast<uint32_t>(dst->op_params[0]);
9644+ p.s1 = static_cast<uint32_t>(dst->op_params[0]);
9645+ p.p0 = 0;
9646+ p.p1 = 0;
9647+ p.d0 = 1;
9648+ p.d1 = 1;
9649+
9650+ p.nb01 = static_cast<uint32_t>(nb01 / nb00);
9651+ p.nb02 = static_cast<uint32_t>(nb02 / nb00);
9652+ p.nb03 = static_cast<uint32_t>(nb03 / nb00);
9653+
9654+ p.nb11 = static_cast<uint32_t>(nb11 / nb10);
9655+ p.nb12 = static_cast<uint32_t>(nb12 / nb10);
9656+ p.nb13 = static_cast<uint32_t>(nb13 / nb10);
9657+
9658+ p.nb1 = static_cast<uint32_t>(nb1 / nb0);
9659+ p.nb2 = static_cast<uint32_t>(nb2 / nb0);
9660+ p.nb3 = static_cast<uint32_t>(nb3 / nb0);
9661+
9662+ GGML_ASSERT(ne02 == ne2);
9663+ GGML_ASSERT(ne03 == ne12);
9664+
9665+ ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_TRANSPOSE_2D, std::move(p), dryrun);
9666+ }
9667+
95269668static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
95279669 vk_op_conv2d_dw_push_constants p{};
95289670 p.ne = ggml_nelements(dst);
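One detail worth calling out in ggml_vk_conv_transpose_2d above: ggml's nb fields are byte strides, so dividing by the element size (nb00, nb10, nb0) converts them to element strides before they reach the shader. For a contiguous kernel with ne = [KW, KH, Cout, Cin], for instance, p.nb01 = KW, p.nb02 = KW * KH and p.nb03 = KW * KH * Cout. The two trailing asserts encode the shape contract from the layout comments earlier in the patch: kernel ne[2] must match the output's Cout (ne2) and kernel ne[3] must match the input's Cin (ne12).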
@@ -10615,6 +10757,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
1061510757 case GGML_OP_CONV_TRANSPOSE_1D:
1061610758 case GGML_OP_POOL_2D:
1061710759 case GGML_OP_CONV_2D:
10760+ case GGML_OP_CONV_TRANSPOSE_2D:
1061810761 case GGML_OP_CONV_2D_DW:
1061910762 case GGML_OP_RWKV_WKV6:
1062010763 case GGML_OP_RWKV_WKV7:
@@ -10686,6 +10829,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
1068610829 case GGML_OP_CONV_TRANSPOSE_1D:
1068710830 case GGML_OP_POOL_2D:
1068810831 case GGML_OP_CONV_2D:
10832+ case GGML_OP_CONV_TRANSPOSE_2D:
1068910833 case GGML_OP_CONV_2D_DW:
1069010834 case GGML_OP_LEAKY_RELU:
1069110835 case GGML_OP_OPT_STEP_SGD:
@@ -10997,6 +11141,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
1099711141 case GGML_OP_CONV_2D:
1099811142 ggml_vk_conv_2d(ctx, compute_ctx, src0, src1, node, dryrun);
1099911143
11144+ break;
11145+ case GGML_OP_CONV_TRANSPOSE_2D:
11146+ ggml_vk_conv_transpose_2d(ctx, compute_ctx, src0, src1, node, dryrun);
11147+
1100011148 break;
1100111149 case GGML_OP_CONV_2D_DW:
1100211150 ggml_vk_conv_2d_dw(ctx, compute_ctx, src0, src1, node, dryrun);
@@ -11137,6 +11285,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
1113711285 case GGML_OP_CONV_TRANSPOSE_1D:
1113811286 case GGML_OP_POOL_2D:
1113911287 case GGML_OP_CONV_2D:
11288+ case GGML_OP_CONV_TRANSPOSE_2D:
1114011289 case GGML_OP_CONV_2D_DW:
1114111290 case GGML_OP_RWKV_WKV6:
1114211291 case GGML_OP_RWKV_WKV7:
@@ -11794,10 +11943,10 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
1179411943 ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false);
1179511944 if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) {
1179611945 total_mat_mul_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
11797- } else if (cgraph->nodes[i]->op == GGML_OP_CONV_2D) {
11946+ } else if (cgraph->nodes[i]->op == GGML_OP_CONV_2D || cgraph->nodes[i]->op == GGML_OP_CONV_TRANSPOSE_2D) {
1179811947 // Return CRSxNPQxsizeof(*) to account as many bytes as mul_mat has in im2col->mul_mat mode.
1179911948 auto CRS_size =
11800- cgraph->nodes[i]->src[0]->ne[0] * cgraph->nodes[i]->src[0]->ne[1] * cgraph->nodes[i]->src[0]->ne[2];
11949+ cgraph->nodes[i]->src[0]->ne[0] * cgraph->nodes[i]->src[0]->ne[1] * cgraph->nodes[i]->src[1]->ne[2];
1180111950 auto NPQ_size = cgraph->nodes[i]->ne[0] * cgraph->nodes[i]->ne[1] * cgraph->nodes[i]->ne[3];
1180211951 total_mat_mul_bytes += NPQ_size * CRS_size * ggml_type_size(cgraph->nodes[i]->type);
1180311952 }
@@ -12618,10 +12767,15 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1261812767 case GGML_OP_CONV_TRANSPOSE_1D:
1261912768 return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
1262012769 case GGML_OP_CONV_2D:
12770+ case GGML_OP_CONV_TRANSPOSE_2D:
1262112771 {
1262212772 // Op is disabled for Apple because it segfaults at pipeline create time on MoltenVK
1262312773 ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
1262412774 const vk_device& device = ggml_vk_get_device(ctx->device);
12775+ if (op->op == GGML_OP_CONV_TRANSPOSE_2D &&
12776+ device->properties.limits.maxPushConstantsSize < sizeof(vk_op_conv_transpose_2d_push_constants)) {
12777+ return false;
12778+ }
1262512779 // Channel-contiguous format is not supported yet.
1262612780 return ((op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
1262712781 op->src[1]->type == GGML_TYPE_F32 &&
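The extra maxPushConstantsSize check exists because vk_op_conv_transpose_2d_push_constants is larger than its conv2d counterpart: 36 uint32_t fields, i.e. 144 bytes, which exceeds the 128-byte minimum that the Vulkan spec guarantees for push constants. Devices that only advertise the minimum therefore report the op as unsupported (and, per the guard in ggml_vk_load_shaders, never create the pipelines) instead of failing at pipeline-creation time.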
@@ -13240,6 +13394,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
1324013394 const int32_t d0 = tensor->op_params[4];
1324113395 const int32_t d1 = tensor->op_params[5];
1324213396 tensor_clone = ggml_conv_2d(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1);
13397+ } else if (tensor->op == GGML_OP_CONV_TRANSPOSE_2D) {
13398+ const int32_t s = tensor->op_params[0];
13399+ tensor_clone = ggml_conv_transpose_2d_p0(ggml_ctx, src_clone[0], src_clone[1], s);
1324313400 } else if (tensor->op == GGML_OP_LEAKY_RELU) {
1324413401 const float * op_params = (const float *)tensor->op_params;
1324513402 tensor_clone = ggml_leaky_relu(ggml_ctx, src_clone[0], op_params[0], false);
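For reference, a minimal sketch of how the CPU reference op invoked by the results check is built; the tensor shapes are made-up example values, and it assumes a valid ggml_context like the checker's ggml_ctx.

    #include "ggml.h"

    // Sketch: build the same reference op the checker calls, on a caller-provided context.
    static struct ggml_tensor * conv_transpose_2d_ref_example(struct ggml_context * ggml_ctx) {
        // kernel: [KW, KH, Cout, Cin] (F16, which the ggml CPU reference path expects),
        // input:  [W, H, Cin, N]; one stride is shared by both spatial dimensions.
        struct ggml_tensor * knl = ggml_new_tensor_4d(ggml_ctx, GGML_TYPE_F16, 3, 3, 16, 8);
        struct ggml_tensor * inp = ggml_new_tensor_4d(ggml_ctx, GGML_TYPE_F32, 32, 32, 8, 1);
        struct ggml_tensor * out = ggml_conv_transpose_2d_p0(ggml_ctx, knl, inp, /*stride=*/2);
        // out->ne == { (32 - 1) * 2 + 3, (32 - 1) * 2 + 3, 16, 1 } == { 65, 65, 16, 1 }
        return out;
    }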