
Commit 9d1b723

Vulkan: add conv_transpose_2d operation
1 parent 51abc96 commit 9d1b723

4 files changed: +536 −28 lines


ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 182 additions & 28 deletions
@@ -574,6 +574,8 @@ struct vk_device_struct {
     vk_pipeline pipeline_opt_step_sgd_f32;
     vk_pipeline pipeline_conv2d_f32[CONV_SHAPE_COUNT];
     vk_pipeline pipeline_conv2d_f16_f32[CONV_SHAPE_COUNT];
+    vk_pipeline pipeline_conv_transpose_2d_f32[CONV_SHAPE_COUNT];
+    vk_pipeline pipeline_conv_transpose_2d_f16_f32[CONV_SHAPE_COUNT];
     vk_pipeline pipeline_conv2d_dw_whcn_f32, pipeline_conv2d_dw_whcn_f16_f32;
     vk_pipeline pipeline_conv2d_dw_cwhn_f32, pipeline_conv2d_dw_cwhn_f16_f32;

@@ -1117,6 +1119,56 @@ template <> void init_pushconst_fastdiv(vk_op_conv2d_push_constants &p) {
     init_fastdiv_values(p.OW*p.OH, p.OWOHmp, p.OWOHL);
 }
 
+struct vk_op_conv_transpose_2d_push_constants {
+    uint32_t Cout;
+    uint32_t Cin;
+    uint32_t N;
+
+    uint32_t KW;
+    uint32_t KH;
+    uint32_t W;
+    uint32_t H;
+    uint32_t OW;
+    uint32_t OH;
+
+    uint32_t s0;
+    uint32_t s1;
+    uint32_t p0;
+    uint32_t p1;
+    uint32_t d0;
+    uint32_t d1;
+
+    uint32_t nb01;
+    uint32_t nb02;
+    uint32_t nb03;
+
+    uint32_t nb11;
+    uint32_t nb12;
+    uint32_t nb13;
+
+    uint32_t nb1;
+    uint32_t nb2;
+    uint32_t nb3;
+
+    // init_fastdiv_values constants for dividing by KW, KW*KH, OW, OW*OH, s0, s1
+    uint32_t KWmp; uint32_t KWL;
+    uint32_t KWKHmp; uint32_t KWKHL;
+    uint32_t OWmp; uint32_t OWL;
+    uint32_t OWOHmp; uint32_t OWOHL;
+    uint32_t s0mp; uint32_t s0L;
+    uint32_t s1mp; uint32_t s1L;
+};
+
+template <> void init_pushconst_fastdiv(vk_op_conv_transpose_2d_push_constants &p) {
+    // Compute magic values to divide by KW, KW*KH, OW, OW*OH, s0, s1
+    init_fastdiv_values(p.KW, p.KWmp, p.KWL);
+    init_fastdiv_values(p.KW*p.KH, p.KWKHmp, p.KWKHL);
+    init_fastdiv_values(p.OW, p.OWmp, p.OWL);
+    init_fastdiv_values(p.OW*p.OH, p.OWOHmp, p.OWOHL);
+    init_fastdiv_values(p.s0, p.s0mp, p.s0L);
+    init_fastdiv_values(p.s1, p.s1mp, p.s1L);
+}
+
 struct vk_op_conv2d_dw_push_constants {
     uint32_t ne;
     uint32_t batches;
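
Note (illustration, not part of the commit): the *mp/*L pairs above let the shader replace integer division by a divisor known only at dispatch time with a multiply-high and a shift. A minimal sketch of the usual magic-number scheme, assuming that is what init_fastdiv_values implements; only the helper name comes from the diff, the code below is illustrative:

    #include <cstdint>

    // Precompute mp and L so that, for any 32-bit n and this fixed divisor d
    // (assumed 1 <= d < 2^31), n / d == ((uint64_t(n) * mp >> 32) + n) >> L.
    static void fastdiv_magic(uint32_t d, uint32_t & mp, uint32_t & L) {
        L = 0;
        while (L < 32 && (uint32_t{1} << L) < d) {
            L++;                                   // L = ceil(log2(d))
        }
        mp = (uint32_t)(((uint64_t{1} << 32) * ((uint64_t{1} << L) - d)) / d + 1);
    }

    static uint32_t fastdiv(uint32_t n, uint32_t mp, uint32_t L) {
        uint64_t hi = ((uint64_t) n * mp) >> 32;   // multiply-high
        return (uint32_t) ((hi + n) >> L);         // equals n / d
    }
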
@@ -1313,7 +1365,7 @@ class vk_perf_logger {
             flops[name].push_back(m * n * (k + (k - 1)) * batch);
             return;
         }
-        if (node->op == GGML_OP_CONV_2D) {
+        if (node->op == GGML_OP_CONV_2D || node->op == GGML_OP_CONV_TRANSPOSE_2D) {
             std::string name = ggml_op_name(node->op);
             ggml_tensor * knl = node->src[0];
             uint64_t OW = node->ne[0];
@@ -1322,7 +1374,7 @@ class vk_perf_logger {
             uint64_t Cout = node->ne[2];
             uint64_t KW = knl->ne[0];
             uint64_t KH = knl->ne[1];
-            uint64_t Cin = knl->ne[2];
+            uint64_t Cin = node->src[1]->ne[2];
             // KxCRS @ CRSxNPQ = KxNPQ -> M=K, K=CRS, N=NPQ
             uint64_t size_M = Cout;
             uint64_t size_K = Cin * KW * KH;
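
Note (illustration, not part of the commit): the logger views both ops as one implicit GEMM with M = Cout, K = Cin * KW * KH and N = NPQ = batch * OW * OH, following the m * n * (k + (k - 1)) flop counting visible just above (K multiplies plus K - 1 adds per output element). With illustrative sizes Cout = 64, Cin = 32, KW = KH = 3 and a single 64x64 output image, that is M = 64, K = 288, NPQ = 4096, roughly 64 * 4096 * 575 ≈ 1.5e8 flops.
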
@@ -3471,7 +3523,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_opt_step_sgd_f32, "opt_step_sgd_f32", opt_step_sgd_f32_len, opt_step_sgd_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
 
-    // conv2d
+    // conv2d, conv_transpose_2d
     for (uint32_t s = 0; s < CONV_SHAPE_COUNT; ++s) {
         uint32_t conv2d_WG_SIZE = 256;
         uint32_t conv2d_BS_K = 128;
@@ -3546,31 +3598,31 @@ static void ggml_vk_load_shaders(vk_device& device) {
         std::array<uint32_t, 3> wg_denoms = { conv2d_BS_K, conv2d_BS_NPQ, 1 };
         std::vector<uint32_t> spec_constants = { conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives, conv2d_SHMEM_PAD };
 
+#define CREATE_CONV(name, type_suffix, spv_suffix) \
+        ggml_vk_create_pipeline( \
+            device, device->pipeline_##name##type_suffix[s], #name #type_suffix, \
+            name##type_suffix##spv_suffix##_len, name##type_suffix##spv_suffix##_data, "main", 3, \
+            sizeof(vk_op_##name##_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);
 #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
         if (device->coopmat2) {
-            ggml_vk_create_pipeline(
-                device, device->pipeline_conv2d_f32[s], "conv2d_f32", conv2d_f32_cm2_len, conv2d_f32_cm2_data, "main", 3,
-                sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);
-            ggml_vk_create_pipeline(
-                device, device->pipeline_conv2d_f16_f32[s], "conv2d_f16_f32", conv2d_f16_f32_cm2_len, conv2d_f16_f32_cm2_data, "main", 3,
-                sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);
+            CREATE_CONV(conv2d, _f32, _cm2)
+            CREATE_CONV(conv2d, _f16_f32, _cm2)
+            CREATE_CONV(conv_transpose_2d, _f32, _cm2)
+            CREATE_CONV(conv_transpose_2d, _f16_f32, _cm2)
         } else
 #endif
         if (conv2d_UNROLL) {
-            ggml_vk_create_pipeline(
-                device, device->pipeline_conv2d_f32[s], "conv2d_f32", conv2d_f32_unroll_len, conv2d_f32_unroll_data, "main", 3,
-                sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);
-            ggml_vk_create_pipeline(
-                device, device->pipeline_conv2d_f16_f32[s], "conv2d_f16_f32", conv2d_f16_f32_unroll_len, conv2d_f16_f32_unroll_data, "main", 3,
-                sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);
+            CREATE_CONV(conv2d, _f32, _unroll)
+            CREATE_CONV(conv2d, _f16_f32, _unroll)
+            CREATE_CONV(conv_transpose_2d, _f32, _unroll)
+            CREATE_CONV(conv_transpose_2d, _f16_f32, _unroll)
         } else {
-            ggml_vk_create_pipeline(
-                device, device->pipeline_conv2d_f32[s], "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3,
-                sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);
-            ggml_vk_create_pipeline(
-                device, device->pipeline_conv2d_f16_f32[s], "conv2d_f16_f32", conv2d_f16_f32_len, conv2d_f16_f32_data, "main", 3,
-                sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);
+            CREATE_CONV(conv2d, _f32, )
+            CREATE_CONV(conv2d, _f16_f32, )
+            CREATE_CONV(conv_transpose_2d, _f32, )
+            CREATE_CONV(conv_transpose_2d, _f16_f32, )
         }
+#undef CREATE_CONV
     }
 
     ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
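
Note (illustration, not part of the commit): the macro only exists to avoid spelling out the pipeline-creation call for every type and shader variant. Expanding, say, CREATE_CONV(conv_transpose_2d, _f32, _cm2) by hand pastes the tokens into:

    ggml_vk_create_pipeline(
        device, device->pipeline_conv_transpose_2d_f32[s], "conv_transpose_2d" "_f32",
        conv_transpose_2d_f32_cm2_len, conv_transpose_2d_f32_cm2_data, "main", 3,
        sizeof(vk_op_conv_transpose_2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);

The two adjacent string literals concatenate to "conv_transpose_2d_f32", matching the naming convention of the explicit calls the macro replaces.
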
@@ -7502,6 +7554,33 @@ static std::array<uint32_t, 3> ggml_vk_get_conv_elements(const ggml_tensor *dst)
     return elements;
 }
 
+static std::array<uint32_t, 3> ggml_vk_get_conv_transpose_2d_elements(const ggml_tensor *dst) {
+    const ggml_tensor *src0 = dst->src[0];
+    const ggml_tensor *src1 = dst->src[1];
+
+    // src0 - kernel: [KW, KH, Cout, Cin]
+    // src1 - input:  [W, H, Cin, N]
+    // dst - result:  [OW, OH, Cout, N]
+
+    auto calc_conv_output_size = [](int64_t ins, int64_t ks, int s, int p, int d) -> int64_t {
+        return (ins - 1) * s - 2 * p + (ks - 1) * d + 1;
+    };
+    // parallelize in {OW/BS_K, OH/BS_NPQ, 1}
+    int64_t W = src1->ne[0];
+    int64_t H = src1->ne[1];
+    int64_t KW = src0->ne[0];
+    int64_t KH = src0->ne[1];
+    int64_t Cout = src0->ne[2];
+    int64_t N = src1->ne[3];
+    int64_t OH = calc_conv_output_size(H, KH, dst->op_params[0], 0, 1);
+    int64_t OW = calc_conv_output_size(W, KW, dst->op_params[0], 0, 1);
+    int64_t NPQ = N * OW * OH;
+
+    // Tile output matrix to (K/NB_K, NPQ/NB_NPQ, 1) workgroups
+    std::array<uint32_t, 3> elements = { static_cast<uint32_t>(Cout), static_cast<uint32_t>(NPQ), 1 };
+    return elements;
+}
+
 static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * dst, ggml_op op) {
     switch (op) {
     case GGML_OP_GET_ROWS:
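
Note (illustration, not part of the commit): with the padding p = 0 and dilation d = 1 hard-coded here and the stride s taken from dst->op_params[0], calc_conv_output_size reduces to OW = (W - 1) * s + KW. For example, W = H = 16, KW = KH = 4, s = 2 gives OW = OH = (16 - 1) * 2 + (4 - 1) + 1 = 34, so the dispatch covers Cout x (N * 34 * 34) output elements, tiled into workgroups the same way as the conv2d path.
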
@@ -7879,9 +7958,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
         }
         return nullptr;
     case GGML_OP_CONV_2D:
+    case GGML_OP_CONV_TRANSPOSE_2D:
         if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
             ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
-            auto elements = ggml_vk_get_conv_elements(dst);
+            std::array<uint32_t, 3> elements;
+            if (op == GGML_OP_CONV_2D) elements = ggml_vk_get_conv_elements(dst);
+            else if (op == GGML_OP_CONV_TRANSPOSE_2D) elements = ggml_vk_get_conv_transpose_2d_elements(dst);
             vk_conv_shapes shape;
 
             uint32_t tiles[CONV_SHAPE_COUNT];
@@ -7901,10 +7983,18 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
                 shape = CONV_SHAPE_64x32;
             }
 
-            if (src0->type == GGML_TYPE_F32) {
-                return ctx->device->pipeline_conv2d_f32[shape];
-            } else if (src0->type == GGML_TYPE_F16) {
-                return ctx->device->pipeline_conv2d_f16_f32[shape];
+            if (op == GGML_OP_CONV_2D) {
+                if (src0->type == GGML_TYPE_F32) {
+                    return ctx->device->pipeline_conv2d_f32[shape];
+                } else if (src0->type == GGML_TYPE_F16) {
+                    return ctx->device->pipeline_conv2d_f16_f32[shape];
+                }
+            } else if (op == GGML_OP_CONV_TRANSPOSE_2D) {
+                if (src0->type == GGML_TYPE_F32) {
+                    return ctx->device->pipeline_conv_transpose_2d_f32[shape];
+                } else if (src0->type == GGML_TYPE_F16) {
+                    return ctx->device->pipeline_conv_transpose_2d_f16_f32[shape];
+                }
             }
         }
         return nullptr;
@@ -8304,6 +8394,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
         {
            elements = ggml_vk_get_conv_elements(dst);
         } break;
+    case GGML_OP_CONV_TRANSPOSE_2D:
+        {
+            elements = ggml_vk_get_conv_transpose_2d_elements(dst);
+        } break;
     case GGML_OP_ADD:
     case GGML_OP_SUB:
     case GGML_OP_DIV:
@@ -9477,6 +9571,55 @@ static void ggml_vk_conv_2d(ggml_backend_vk_context * ctx, vk_context & subctx,
     ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_2D, std::move(p), dryrun);
 }
 
+static void ggml_vk_conv_transpose_2d(ggml_backend_vk_context * ctx, vk_context & subctx, const ggml_tensor * src0,
+                                      const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    GGML_ASSERT(nb00 == sizeof(float) || nb00 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nb10 == sizeof(float));
+    GGML_ASSERT(nb0 == sizeof(float));
+
+    vk_op_conv_transpose_2d_push_constants p{};
+    p.Cout = static_cast<uint32_t>(ne02);
+    p.Cin = static_cast<uint32_t>(ne03);
+    p.N = static_cast<uint32_t>(ne13);
+
+    p.KW = static_cast<uint32_t>(ne00);
+    p.KH = static_cast<uint32_t>(ne01);
+    p.W = static_cast<uint32_t>(ne10);
+    p.H = static_cast<uint32_t>(ne11);
+    p.OW = static_cast<uint32_t>(ne0);
+    p.OH = static_cast<uint32_t>(ne1);
+
+    p.s0 = static_cast<uint32_t>(dst->op_params[0]);
+    p.s1 = static_cast<uint32_t>(dst->op_params[0]);
+    p.p0 = 0;
+    p.p1 = 0;
+    p.d0 = 1;
+    p.d1 = 1;
+
+    p.nb01 = static_cast<uint32_t>(nb01 / nb00);
+    p.nb02 = static_cast<uint32_t>(nb02 / nb00);
+    p.nb03 = static_cast<uint32_t>(nb03 / nb00);
+
+    p.nb11 = static_cast<uint32_t>(nb11 / nb10);
+    p.nb12 = static_cast<uint32_t>(nb12 / nb10);
+    p.nb13 = static_cast<uint32_t>(nb13 / nb10);
+
+    p.nb1 = static_cast<uint32_t>(nb1 / nb0);
+    p.nb2 = static_cast<uint32_t>(nb2 / nb0);
+    p.nb3 = static_cast<uint32_t>(nb3 / nb0);
+
+    GGML_ASSERT(ne02 == ne2);
+    GGML_ASSERT(ne03 == ne12);
+
+    ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_TRANSPOSE_2D, std::move(p), dryrun);
+}
+
 static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
     vk_op_conv2d_dw_push_constants p{};
     p.ne = ggml_nelements(dst);
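
Note (illustration, not part of the commit): a minimal sketch of how a graph node reaches this function from the ggml API. Sizes are hypothetical and follow the layout comments earlier in the diff (kernel [KW, KH, Cout, Cin], input [W, H, Cin, N]); context creation and backend setup are omitted.

    // Hypothetical sizes: KW = KH = 3, Cout = 64, Cin = 32, W = H = 16, N = 1.
    struct ggml_tensor * kernel = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, 64, 32);
    struct ggml_tensor * input  = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 16, 16, 32, 1);
    // Stride 2, no padding; op_params[0] carries the single stride, which is why
    // s0 and s1 above are both read from dst->op_params[0].
    struct ggml_tensor * out = ggml_conv_transpose_2d_p0(ctx, kernel, input, 2);
    // out has shape [OW, OH, Cout, N] = [33, 33, 64, 1], since OW = (W - 1)*s + KW.
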
@@ -10569,6 +10712,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_CONV_TRANSPOSE_1D:
     case GGML_OP_POOL_2D:
     case GGML_OP_CONV_2D:
+    case GGML_OP_CONV_TRANSPOSE_2D:
     case GGML_OP_CONV_2D_DW:
     case GGML_OP_RWKV_WKV6:
     case GGML_OP_RWKV_WKV7:
@@ -10640,6 +10784,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_CONV_TRANSPOSE_1D:
     case GGML_OP_POOL_2D:
     case GGML_OP_CONV_2D:
+    case GGML_OP_CONV_TRANSPOSE_2D:
     case GGML_OP_CONV_2D_DW:
     case GGML_OP_LEAKY_RELU:
     case GGML_OP_OPT_STEP_SGD:
@@ -10951,6 +11096,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_CONV_2D:
         ggml_vk_conv_2d(ctx, compute_ctx, src0, src1, node, dryrun);
 
+        break;
+    case GGML_OP_CONV_TRANSPOSE_2D:
+        ggml_vk_conv_transpose_2d(ctx, compute_ctx, src0, src1, node, dryrun);
+
         break;
     case GGML_OP_CONV_2D_DW:
         ggml_vk_conv_2d_dw(ctx, compute_ctx, src0, src1, node, dryrun);
@@ -11091,6 +11240,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
     case GGML_OP_CONV_TRANSPOSE_1D:
     case GGML_OP_POOL_2D:
     case GGML_OP_CONV_2D:
+    case GGML_OP_CONV_TRANSPOSE_2D:
     case GGML_OP_CONV_2D_DW:
     case GGML_OP_RWKV_WKV6:
     case GGML_OP_RWKV_WKV7:
@@ -11743,10 +11893,10 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
         ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false);
         if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) {
             total_mat_mul_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
-        } else if (cgraph->nodes[i]->op == GGML_OP_CONV_2D) {
+        } else if (cgraph->nodes[i]->op == GGML_OP_CONV_2D || cgraph->nodes[i]->op == GGML_OP_CONV_TRANSPOSE_2D) {
             // Return CRSxNPQxsizeof(*) to account as many bytes as mul_mat has in im2col->mul_mat mode.
             auto CRS_size =
-                cgraph->nodes[i]->src[0]->ne[0] * cgraph->nodes[i]->src[0]->ne[1] * cgraph->nodes[i]->src[0]->ne[2];
+                cgraph->nodes[i]->src[0]->ne[0] * cgraph->nodes[i]->src[0]->ne[1] * cgraph->nodes[i]->src[1]->ne[2];
             auto NPQ_size = cgraph->nodes[i]->ne[0] * cgraph->nodes[i]->ne[1] * cgraph->nodes[i]->ne[3];
             total_mat_mul_bytes += NPQ_size * CRS_size * ggml_type_size(cgraph->nodes[i]->type);
         }
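
Note (illustration, not part of the commit): reusing the sizes from the perf-logger note above, KW = KH = 3 and Cin = 32 give CRS = 288, and an f32 result of 64x64 with N = 1 gives NPQ = 4096, so this node adds 288 * 4096 * 4 bytes ≈ 4.7 MB to total_mat_mul_bytes, roughly the weight the same convolution would carry if it were lowered to im2col followed by mul_mat.
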
@@ -12567,6 +12717,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
     case GGML_OP_CONV_TRANSPOSE_1D:
         return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
     case GGML_OP_CONV_2D:
+    case GGML_OP_CONV_TRANSPOSE_2D:
         {
             // Op is disabled for Apple because it segfaults at pipeline create time on MoltenVK
             ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
@@ -13175,6 +13326,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
         const int32_t d0 = tensor->op_params[4];
         const int32_t d1 = tensor->op_params[5];
         tensor_clone = ggml_conv_2d(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1);
+    } else if (tensor->op == GGML_OP_CONV_TRANSPOSE_2D) {
+        const int32_t s = tensor->op_params[0];
+        tensor_clone = ggml_conv_transpose_2d_p0(ggml_ctx, src_clone[0], src_clone[1], s);
     } else if (tensor->op == GGML_OP_LEAKY_RELU) {
         const float * op_params = (const float *)tensor->op_params;
         tensor_clone = ggml_leaky_relu(ggml_ctx, src_clone[0], op_params[0], false);
