@@ -398,6 +398,7 @@ struct vk_device_struct {
398398 vk_pipeline pipeline_count_equal_i32;
399399 vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
400400 vk_pipeline pipeline_timestep_embedding_f32;
401+ vk_pipeline pipeline_conv_transpose_1d_f32;
401402 vk_pipeline pipeline_pool2d_f32;
402403 vk_pipeline pipeline_rwkv_wkv6_f32;
403404 vk_pipeline pipeline_rwkv_wkv7_f32;
@@ -706,6 +707,21 @@ struct vk_op_timestep_embedding_push_constants {
706707 uint32_t max_period;
707708};
708709
// Push-constant block for the CONV_TRANSPOSE_1D compute shader.
// Field order/types form the ABI shared with the SPIR-V shader — do not reorder.
// All strides are in elements, not bytes (the byte strides are divided by the
// element size before being stored here).
struct vk_op_conv_transpose_1d_push_constants {
    uint32_t Cout;  // output channels            (src0->ne[1])
    uint32_t Cin;   // input channels             (src0->ne[2])
    uint32_t K;     // kernel width               (src0->ne[0])
    uint32_t L;     // input length               (src1->ne[0])
    uint32_t KL;    // output length              (dst->ne[0])

    uint32_t nb01;  // src0 stride between Cout kernels, in elements
    uint32_t nb02;  // src0 stride between Cin planes,   in elements
    uint32_t nb11;  // src1 stride between channels,     in elements
    uint32_t nb1;   // dst  stride between channels,     in elements

    int32_t s0;     // transposed-convolution stride (dst->op_params[0])
};
724+
709725struct vk_op_pool2d_push_constants {
710726 uint32_t IW; uint32_t IH;
711727 uint32_t OW; uint32_t OH;
@@ -2727,6 +2743,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
27272743
27282744 ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1);
27292745
2746+ ggml_vk_create_pipeline(device, device->pipeline_conv_transpose_1d_f32, "conv_transpose_1d_f32", conv_transpose_1d_f32_len, conv_transpose_1d_f32_data, "main", 3, sizeof(vk_op_conv_transpose_1d_push_constants), {1, 1, 1}, {}, 1);
2747+
27302748 ggml_vk_create_pipeline(device, device->pipeline_pool2d_f32, "pool2d_f32", pool2d_f32_len, pool2d_f32_data, "main", 2, sizeof(vk_op_pool2d_push_constants), {512, 1, 1}, {}, 1);
27312749
27322750 ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
@@ -6391,6 +6409,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
63916409 return ctx->device->pipeline_timestep_embedding_f32;
63926410 }
63936411 return nullptr;
6412+ case GGML_OP_CONV_TRANSPOSE_1D:
6413+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
6414+ return ctx->device->pipeline_conv_transpose_1d_f32;
6415+ }
6416+ return nullptr;
63946417 case GGML_OP_POOL_2D:
63956418 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
63966419 return ctx->device->pipeline_pool2d_f32;
@@ -6725,6 +6748,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
67256748 uint32_t half_ceil = (dim + 1) / 2;
67266749 elements = { half_ceil, (uint32_t)src0->ne[0], 1 };
67276750 } break;
6751+ case GGML_OP_CONV_TRANSPOSE_1D:
6752+ {
6753+ elements = {uint32_t(src0->ne[1]), 1, 1}; // parallelize in {Cout, 1, 1}
6754+ } break;
67286755 case GGML_OP_POOL_2D:
67296756 {
67306757 const uint32_t N = dst->ne[3];
@@ -7528,6 +7555,37 @@ static void ggml_vk_timestep_embedding(ggml_backend_vk_context * ctx, vk_context
75287555 }, dryrun);
75297556}
75307557
// Record a 1D transposed convolution (GGML_OP_CONV_TRANSPOSE_1D) on the Vulkan backend.
//
// src0: (K, Cout, Cin, 1) -- kernel
// src1: (L, Cin, 1, 1) -- input
// dst:  (*, Cout, 1, 1) -- output (dst->ne[0] is the output length)
//
// All three tensors must be F32. Byte strides from the tensors are converted to
// element strides for the shader. Only op_params[0] (the stride s0) is consumed
// here; padding/dilation params are not read by this implementation.
// `dryrun` is forwarded to ggml_vk_op_f32 (presumably to pre-allocate
// pipelines/buffers without recording work — confirm against that helper).
static void ggml_vk_conv_transpose_1d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
    // src0: (K, Cout, Cin, 1) -- kernel
    // src1: (L, Cin, 1, 1) -- input
    // dst: (*, Cout, 1, 1)

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    // expands ne0*/nb0* (src0), ne1*/nb1* (src1), ne*/nb* (dst) locals
    GGML_TENSOR_BINARY_OP_LOCALS

    // innermost dimension of kernel and input must be contiguous floats,
    // so the byte-to-element stride divisions below are exact
    GGML_ASSERT(nb00 == sizeof(float));
    GGML_ASSERT(nb10 == sizeof(float));

    const int32_t s0 = dst->op_params[0];

    vk_op_conv_transpose_1d_push_constants p{};
    p.Cout = ne01;
    p.Cin = ne02;
    p.K = ne00;
    p.L = ne10;
    p.KL = ne0;
    p.nb01 = nb01 / nb00;  // convert byte strides to element strides
    p.nb02 = nb02 / nb00;
    p.nb11 = nb11 / nb10;
    p.nb1 = nb1 / nb0;
    p.s0 = s0;

    ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_TRANSPOSE_1D, std::move(p), dryrun);
}
7588+
75317589static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
75327590 uint32_t op = static_cast<uint32_t>(dst->op_params[0]);
75337591 const int32_t k1 = dst->op_params[1];
@@ -8599,6 +8657,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
85998657 case GGML_OP_COUNT_EQUAL:
86008658 case GGML_OP_IM2COL:
86018659 case GGML_OP_TIMESTEP_EMBEDDING:
8660+ case GGML_OP_CONV_TRANSPOSE_1D:
86028661 case GGML_OP_POOL_2D:
86038662 case GGML_OP_CONV_2D_DW:
86048663 case GGML_OP_RWKV_WKV6:
@@ -8663,6 +8722,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
86638722 case GGML_OP_COUNT_EQUAL:
86648723 case GGML_OP_IM2COL:
86658724 case GGML_OP_TIMESTEP_EMBEDDING:
8725+ case GGML_OP_CONV_TRANSPOSE_1D:
86668726 case GGML_OP_POOL_2D:
86678727 case GGML_OP_CONV_2D_DW:
86688728 case GGML_OP_LEAKY_RELU:
@@ -8834,6 +8894,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
88348894 case GGML_OP_TIMESTEP_EMBEDDING:
88358895 ggml_vk_timestep_embedding(ctx, compute_ctx, src0, node, dryrun);
88368896
8897+ break;
8898+ case GGML_OP_CONV_TRANSPOSE_1D:
8899+ ggml_vk_conv_transpose_1d(ctx, compute_ctx, src0, src1, node, dryrun);
8900+
88378901 break;
88388902 case GGML_OP_POOL_2D:
88398903 ggml_vk_pool_2d(ctx, compute_ctx, src0, node, dryrun);
@@ -8962,6 +9026,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
89629026 case GGML_OP_COUNT_EQUAL:
89639027 case GGML_OP_IM2COL:
89649028 case GGML_OP_TIMESTEP_EMBEDDING:
9029+ case GGML_OP_CONV_TRANSPOSE_1D:
89659030 case GGML_OP_POOL_2D:
89669031 case GGML_OP_CONV_2D_DW:
89679032 case GGML_OP_RWKV_WKV6:
@@ -9964,6 +10029,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
996410029 case GGML_OP_COUNT_EQUAL:
996510030 case GGML_OP_IM2COL:
996610031 case GGML_OP_TIMESTEP_EMBEDDING:
10032+ case GGML_OP_CONV_TRANSPOSE_1D:
10033+ return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
996710034 case GGML_OP_CONV_2D_DW:
996810035 case GGML_OP_POOL_2D:
996910036 case GGML_OP_RWKV_WKV6:
@@ -10462,6 +10529,11 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
1046210529 const int32_t dim = tensor->op_params[0];
1046310530 const int32_t max_period = tensor->op_params[1];
1046410531 tensor_clone = ggml_timestep_embedding(ggml_ctx, src_clone[0], dim, max_period);
10532+ } else if (tensor->op == GGML_OP_CONV_TRANSPOSE_1D){
10533+ const int32_t s0 = tensor->op_params[0];
10534+ const int32_t p0 = tensor->op_params[1];
10535+ const int32_t d0 = tensor->op_params[2];
10536+ tensor_clonse = ggml_conv_transpose_1d(ggml_ctx, src_clone[0], src_clone[1], s0, p0, d0);
1046510537 } else if (tensor->op == GGML_OP_POOL_2D) {
1046610538 enum ggml_op_pool op = static_cast<ggml_op_pool>(tensor->op_params[0]);
1046710539 const int32_t k0 = tensor->op_params[1];
0 commit comments