From 9d1b7233e175016230cae30149ac3bbaa9518cec Mon Sep 17 00:00:00 2001 From: Shin-myoung-serp Date: Tue, 16 Sep 2025 15:50:44 +0900 Subject: [PATCH 1/7] Vulkan: add conv_transpose_2d operation --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 210 +++++++++-- .../vulkan-shaders/conv_transpose_2d_mm.comp | 339 ++++++++++++++++++ .../vulkan-shaders/vulkan-shaders-gen.cpp | 11 + tests/test-backend-ops.cpp | 4 + 4 files changed, 536 insertions(+), 28 deletions(-) create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_2d_mm.comp diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 60a99dc78b836..218861475ecd6 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -574,6 +574,8 @@ struct vk_device_struct { vk_pipeline pipeline_opt_step_sgd_f32; vk_pipeline pipeline_conv2d_f32[CONV_SHAPE_COUNT]; vk_pipeline pipeline_conv2d_f16_f32[CONV_SHAPE_COUNT]; + vk_pipeline pipeline_conv_transpose_2d_f32[CONV_SHAPE_COUNT]; + vk_pipeline pipeline_conv_transpose_2d_f16_f32[CONV_SHAPE_COUNT]; vk_pipeline pipeline_conv2d_dw_whcn_f32, pipeline_conv2d_dw_whcn_f16_f32; vk_pipeline pipeline_conv2d_dw_cwhn_f32, pipeline_conv2d_dw_cwhn_f16_f32; @@ -1117,6 +1119,56 @@ template <> void init_pushconst_fastdiv(vk_op_conv2d_push_constants &p) { init_fastdiv_values(p.OW*p.OH, p.OWOHmp, p.OWOHL); } +struct vk_op_conv_transpose_2d_push_constants { + uint32_t Cout; + uint32_t Cin; + uint32_t N; + + uint32_t KW; + uint32_t KH; + uint32_t W; + uint32_t H; + uint32_t OW; + uint32_t OH; + + uint32_t s0; + uint32_t s1; + uint32_t p0; + uint32_t p1; + uint32_t d0; + uint32_t d1; + + uint32_t nb01; + uint32_t nb02; + uint32_t nb03; + + uint32_t nb11; + uint32_t nb12; + uint32_t nb13; + + uint32_t nb1; + uint32_t nb2; + uint32_t nb3; + + // init_fastdiv_values constants for dividing by KW, KW*KH, OW, OW*OH, s0, s1 + uint32_t KWmp; uint32_t KWL; + uint32_t KWKHmp; uint32_t KWKHL; + uint32_t OWmp; uint32_t OWL; + uint32_t OWOHmp; uint32_t OWOHL; + uint32_t s0mp; uint32_t s0L; + uint32_t s1mp; uint32_t s1L; +}; + +template <> void init_pushconst_fastdiv(vk_op_conv_transpose_2d_push_constants &p) { + // Compute magic values to divide by KW, KW*KH, OW, OW*OH, s0, s1 + init_fastdiv_values(p.KW, p.KWmp, p.KWL); + init_fastdiv_values(p.KW*p.KH, p.KWKHmp, p.KWKHL); + init_fastdiv_values(p.OW, p.OWmp, p.OWL); + init_fastdiv_values(p.OW*p.OH, p.OWOHmp, p.OWOHL); + init_fastdiv_values(p.s0, p.s0mp, p.s0L); + init_fastdiv_values(p.s1, p.s1mp, p.s1L); +} + struct vk_op_conv2d_dw_push_constants { uint32_t ne; uint32_t batches; @@ -1313,7 +1365,7 @@ class vk_perf_logger { flops[name].push_back(m * n * (k + (k - 1)) * batch); return; } - if (node->op == GGML_OP_CONV_2D) { + if (node->op == GGML_OP_CONV_2D || node->op == GGML_OP_CONV_TRANSPOSE_2D) { std::string name = ggml_op_name(node->op); ggml_tensor * knl = node->src[0]; uint64_t OW = node->ne[0]; @@ -1322,7 +1374,7 @@ class vk_perf_logger { uint64_t Cout = node->ne[2]; uint64_t KW = knl->ne[0]; uint64_t KH = knl->ne[1]; - uint64_t Cin = knl->ne[2]; + uint64_t Cin = node->src[1]->ne[2]; // KxCRS @ CRSxNPQ = KxNPQ -> M=K, K=CRS, N=NPQ uint64_t size_M = Cout; uint64_t size_K = Cin * KW * KH; @@ -3471,7 +3523,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_opt_step_sgd_f32, "opt_step_sgd_f32", opt_step_sgd_f32_len, opt_step_sgd_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); - // conv2d + // conv2d, conv_transpose_2d for 
(uint32_t s = 0; s < CONV_SHAPE_COUNT; ++s) { uint32_t conv2d_WG_SIZE = 256; uint32_t conv2d_BS_K = 128; @@ -3546,31 +3598,31 @@ static void ggml_vk_load_shaders(vk_device& device) { std::array wg_denoms = { conv2d_BS_K, conv2d_BS_NPQ, 1 }; std::vector spec_constants = { conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives, conv2d_SHMEM_PAD }; +#define CREATE_CONV(name, type_suffix, spv_suffix) \ + ggml_vk_create_pipeline( \ + device, device->pipeline_##name##type_suffix[s], #name #type_suffix, \ + name##type_suffix##spv_suffix##_len, name##type_suffix##spv_suffix##_data, "main", 3, \ + sizeof(vk_op_##name##_push_constants), wg_denoms, spec_constants, 1, true, use_collectives); #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) if (device->coopmat2) { - ggml_vk_create_pipeline( - device, device->pipeline_conv2d_f32[s], "conv2d_f32", conv2d_f32_cm2_len, conv2d_f32_cm2_data, "main", 3, - sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives); - ggml_vk_create_pipeline( - device, device->pipeline_conv2d_f16_f32[s], "conv2d_f16_f32", conv2d_f16_f32_cm2_len, conv2d_f16_f32_cm2_data, "main", 3, - sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives); + CREATE_CONV(conv2d, _f32, _cm2) + CREATE_CONV(conv2d, _f16_f32, _cm2) + CREATE_CONV(conv_transpose_2d, _f32, _cm2) + CREATE_CONV(conv_transpose_2d, _f16_f32, _cm2) } else #endif if (conv2d_UNROLL) { - ggml_vk_create_pipeline( - device, device->pipeline_conv2d_f32[s], "conv2d_f32", conv2d_f32_unroll_len, conv2d_f32_unroll_data, "main", 3, - sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives); - ggml_vk_create_pipeline( - device, device->pipeline_conv2d_f16_f32[s], "conv2d_f16_f32", conv2d_f16_f32_unroll_len, conv2d_f16_f32_unroll_data, "main", 3, - sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives); + CREATE_CONV(conv2d, _f32, _unroll) + CREATE_CONV(conv2d, _f16_f32, _unroll) + CREATE_CONV(conv_transpose_2d, _f32, _unroll) + CREATE_CONV(conv_transpose_2d, _f16_f32, _unroll) } else { - ggml_vk_create_pipeline( - device, device->pipeline_conv2d_f32[s], "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3, - sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives); - ggml_vk_create_pipeline( - device, device->pipeline_conv2d_f16_f32[s], "conv2d_f16_f32", conv2d_f16_f32_len, conv2d_f16_f32_data, "main", 3, - sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives); + CREATE_CONV(conv2d, _f32, ) + CREATE_CONV(conv2d, _f16_f32, ) + CREATE_CONV(conv_transpose_2d, _f32, ) + CREATE_CONV(conv_transpose_2d, _f16_f32, ) } +#undef CREATE_CONV } ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1); @@ -7502,6 +7554,33 @@ static std::array ggml_vk_get_conv_elements(const ggml_tensor *dst) return elements; } +static std::array ggml_vk_get_conv_transpose_2d_elements(const ggml_tensor *dst) { + const ggml_tensor *src0 = dst->src[0]; + const ggml_tensor *src1 = dst->src[1]; + + // src0 - kernel: [KW, KH, Cout, Cin] + // src1 - input: [W, H, Cin, N] + // dst - result: [OW, OH, Cout, N] + + auto calc_conv_output_size = [](int64_t ins, int64_t ks, int s, int p, int d) -> int64_t { + return (ins - 1) * s - 2 * p + (ks - 1) * d + 1; + }; + // parallelize in {OW/BS_K, 
OH/BS_NPQ, 1} + int64_t W = src1->ne[0]; + int64_t H = src1->ne[1]; + int64_t KW = src0->ne[0]; + int64_t KH = src0->ne[1]; + int64_t Cout = src0->ne[2]; + int64_t N = src1->ne[3]; + int64_t OH = calc_conv_output_size(H, KH, dst->op_params[0], 0, 1); + int64_t OW = calc_conv_output_size(W, KW, dst->op_params[0], 0, 1); + int64_t NPQ = N * OW * OH; + + // Tile output matrix to (K/NB_K, NPQ/NB_NPQ, 1) workgroups + std::array elements = { static_cast(Cout), static_cast(NPQ), 1 }; + return elements; +} + static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * dst, ggml_op op) { switch (op) { case GGML_OP_GET_ROWS: @@ -7879,9 +7958,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const } return nullptr; case GGML_OP_CONV_2D: + case GGML_OP_CONV_TRANSPOSE_2D: if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) { - auto elements = ggml_vk_get_conv_elements(dst); + std::array elements; + if (op == GGML_OP_CONV_2D) elements = ggml_vk_get_conv_elements(dst); + else if (op == GGML_OP_CONV_TRANSPOSE_2D) elements = ggml_vk_get_conv_transpose_2d_elements(dst); vk_conv_shapes shape; uint32_t tiles[CONV_SHAPE_COUNT]; @@ -7901,10 +7983,18 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const shape = CONV_SHAPE_64x32; } - if (src0->type == GGML_TYPE_F32) { - return ctx->device->pipeline_conv2d_f32[shape]; - } else if (src0->type == GGML_TYPE_F16) { - return ctx->device->pipeline_conv2d_f16_f32[shape]; + if (op == GGML_OP_CONV_2D) { + if (src0->type == GGML_TYPE_F32) { + return ctx->device->pipeline_conv2d_f32[shape]; + } else if (src0->type == GGML_TYPE_F16) { + return ctx->device->pipeline_conv2d_f16_f32[shape]; + } + } else if (op == GGML_OP_CONV_TRANSPOSE_2D) { + if (src0->type == GGML_TYPE_F32) { + return ctx->device->pipeline_conv_transpose_2d_f32[shape]; + } else if (src0->type == GGML_TYPE_F16) { + return ctx->device->pipeline_conv_transpose_2d_f16_f32[shape]; + } } } return nullptr; @@ -8304,6 +8394,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co { elements = ggml_vk_get_conv_elements(dst); } break; + case GGML_OP_CONV_TRANSPOSE_2D: + { + elements = ggml_vk_get_conv_transpose_2d_elements(dst); + } break; case GGML_OP_ADD: case GGML_OP_SUB: case GGML_OP_DIV: @@ -9477,6 +9571,55 @@ static void ggml_vk_conv_2d(ggml_backend_vk_context * ctx, vk_context & subctx, ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_2D, std::move(p), dryrun); } +static void ggml_vk_conv_transpose_2d(ggml_backend_vk_context * ctx, vk_context & subctx, const ggml_tensor * src0, + const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + GGML_TENSOR_BINARY_OP_LOCALS + + GGML_ASSERT(nb00 == sizeof(float) || nb00 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nb10 == sizeof(float)); + GGML_ASSERT(nb0 == sizeof(float)); + + vk_op_conv_transpose_2d_push_constants p{}; + p.Cout = static_cast(ne02); + p.Cin = static_cast(ne03); + p.N = static_cast(ne13); + + p.KW = static_cast(ne00); + p.KH = static_cast(ne01); + p.W = static_cast(ne10); + p.H = static_cast(ne11); + p.OW = static_cast(ne0); + p.OH = static_cast(ne1); + + p.s0 = 
static_cast(dst->op_params[0]); + p.s1 = static_cast(dst->op_params[0]); + p.p0 = 0; + p.p1 = 0; + p.d0 = 1; + p.d1 = 1; + + p.nb01 = static_cast(nb01 / nb00); + p.nb02 = static_cast(nb02 / nb00); + p.nb03 = static_cast(nb03 / nb00); + + p.nb11 = static_cast(nb11 / nb10); + p.nb12 = static_cast(nb12 / nb10); + p.nb13 = static_cast(nb13 / nb10); + + p.nb1 = static_cast(nb1 / nb0); + p.nb2 = static_cast(nb2 / nb0); + p.nb3 = static_cast(nb3 / nb0); + + GGML_ASSERT(ne02 == ne2); + GGML_ASSERT(ne03 == ne12); + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_TRANSPOSE_2D, std::move(p), dryrun); +} + static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { vk_op_conv2d_dw_push_constants p{}; p.ne = ggml_nelements(dst); @@ -10569,6 +10712,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_OP_CONV_TRANSPOSE_1D: case GGML_OP_POOL_2D: case GGML_OP_CONV_2D: + case GGML_OP_CONV_TRANSPOSE_2D: case GGML_OP_CONV_2D_DW: case GGML_OP_RWKV_WKV6: case GGML_OP_RWKV_WKV7: @@ -10640,6 +10784,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_OP_CONV_TRANSPOSE_1D: case GGML_OP_POOL_2D: case GGML_OP_CONV_2D: + case GGML_OP_CONV_TRANSPOSE_2D: case GGML_OP_CONV_2D_DW: case GGML_OP_LEAKY_RELU: case GGML_OP_OPT_STEP_SGD: @@ -10951,6 +11096,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_OP_CONV_2D: ggml_vk_conv_2d(ctx, compute_ctx, src0, src1, node, dryrun); + break; + case GGML_OP_CONV_TRANSPOSE_2D: + ggml_vk_conv_transpose_2d(ctx, compute_ctx, src0, src1, node, dryrun); + break; case GGML_OP_CONV_2D_DW: ggml_vk_conv_2d_dw(ctx, compute_ctx, src0, src1, node, dryrun); @@ -11091,6 +11240,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * case GGML_OP_CONV_TRANSPOSE_1D: case GGML_OP_POOL_2D: case GGML_OP_CONV_2D: + case GGML_OP_CONV_TRANSPOSE_2D: case GGML_OP_CONV_2D_DW: case GGML_OP_RWKV_WKV6: case GGML_OP_RWKV_WKV7: @@ -11743,10 +11893,10 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false); if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) { total_mat_mul_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]); - } else if (cgraph->nodes[i]->op == GGML_OP_CONV_2D) { + } else if (cgraph->nodes[i]->op == GGML_OP_CONV_2D || cgraph->nodes[i]->op == GGML_OP_CONV_TRANSPOSE_2D) { // Return CRSxNPQxsizeof(*) to account as many bytes as mul_mat has in im2col->mul_mat mode. 
auto CRS_size = - cgraph->nodes[i]->src[0]->ne[0] * cgraph->nodes[i]->src[0]->ne[1] * cgraph->nodes[i]->src[0]->ne[2]; + cgraph->nodes[i]->src[0]->ne[0] * cgraph->nodes[i]->src[0]->ne[1] * cgraph->nodes[i]->src[1]->ne[2]; auto NPQ_size = cgraph->nodes[i]->ne[0] * cgraph->nodes[i]->ne[1] * cgraph->nodes[i]->ne[3]; total_mat_mul_bytes += NPQ_size * CRS_size * ggml_type_size(cgraph->nodes[i]->type); } @@ -12567,6 +12717,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_CONV_TRANSPOSE_1D: return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32; case GGML_OP_CONV_2D: + case GGML_OP_CONV_TRANSPOSE_2D: { // Op is disabled for Apple because it segfaults at pipeline create time on MoltenVK ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; @@ -13175,6 +13326,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * const int32_t d0 = tensor->op_params[4]; const int32_t d1 = tensor->op_params[5]; tensor_clone = ggml_conv_2d(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1); + } else if (tensor->op == GGML_OP_CONV_TRANSPOSE_2D) { + const int32_t s = tensor->op_params[0]; + tensor_clone = ggml_conv_transpose_2d_p0(ggml_ctx, src_clone[0], src_clone[1], s); } else if (tensor->op == GGML_OP_LEAKY_RELU) { const float * op_params = (const float *)tensor->op_params; tensor_clone = ggml_leaky_relu(ggml_ctx, src_clone[0], op_params[0], false); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_2d_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_2d_mm.comp new file mode 100644 index 0000000000000..87e9142f39534 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_2d_mm.comp @@ -0,0 +1,339 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable +#ifdef COOPMAT2 +#extension GL_NV_cooperative_matrix2 : enable +#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +#extension GL_KHR_memory_scope_semantics : enable +#endif + +#ifdef USE_COLLECTIVES +# extension GL_KHR_shader_subgroup_shuffle : enable +#endif + +#include "types.comp" + +// shape notation: [dim(N), ..., dim(0)] -- stride(dim(j)) >= stride(dim(i)) if i > j +layout(binding = 0) readonly buffer A { + A_TYPE knl_data[]; +}; // src0 - kernel: [KW, KH, Cout, Cin] + +layout(binding = 1) readonly buffer B { + B_TYPE src_data[]; +}; // src1 - input: [W, H, Cin, N] -- channel_first format + +layout(binding = 2) writeonly buffer D { + D_TYPE dst_data[]; +}; // dst - result: [OW, OH, Cout, N] + +layout(push_constant) uniform parameter { + // I/O channels, batch size + uint32_t Cout; + uint32_t Cin; + uint32_t N; + + // Tensor spatial sizes: kernel, input, output + uint32_t KW; + uint32_t KH; + uint32_t W; + uint32_t H; + uint32_t OW; + uint32_t OH; + + // Parameters: stride, padding, dilation - 0=y, 1=x + uint32_t s0; + uint32_t s1; + uint32_t p0; + uint32_t p1; + uint32_t d0; + uint32_t d1; + + // Strides in elements + uint32_t nb01; + uint32_t nb02; + uint32_t nb03; + + uint32_t nb11; + uint32_t nb12; + uint32_t nb13; + + uint32_t nb1; + uint32_t nb2; + uint32_t nb3; + + // fastdiv helper values + uint32_t KWmp; uint32_t KWL; + uint32_t KWKHmp; uint32_t KWKHL; + uint32_t OWmp; uint32_t OWL; + uint32_t OWOHmp; uint32_t OWOHL; + uint32_t s0mp; uint32_t s0L; + uint32_t s1mp; uint32_t s1L; +} + +p; + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; +// Blocktile sizes +layout(constant_id = 1) const uint BS_K = 128; 
+layout(constant_id = 2) const uint BS_CRS = 16; +layout(constant_id = 3) const uint BS_NPQ = 128; +// Thread-tile sizes +layout(constant_id = 4) const uint TS_K = 8; +layout(constant_id = 5) const uint use_collectives = 1; +layout(constant_id = 6) const uint SHMEM_PAD = 4; + +uint32_t tid = gl_LocalInvocationID.x; +const uint32_t WG_SIZE = gl_WorkGroupSize.x; + +uint splitWork(uint work_size, uint block_size) { + return (block_size + work_size - 1) / block_size; +} + +uint32_t K = p.Cout; +uint32_t CRS = p.Cin * p.KH * p.KW; +uint32_t NPQ = p.N * p.OH * p.OW; + +uint32_t n_elems_out = K * NPQ; + +// Number of blocktiles per input +uint32_t NB_CRS = splitWork(CRS, BS_CRS); + +#ifdef COOPMAT2 +#define SHMEM_TYPE float16_t +#else +#define SHMEM_TYPE float +#endif + +const uint32_t Ash_stride = BS_CRS + SHMEM_PAD; +const uint32_t Bsh_stride = BS_NPQ + SHMEM_PAD; + +const uint32_t Ash_numel = BS_K * BS_CRS; +const uint32_t Bsh_numel = BS_CRS * BS_NPQ; + +const uint32_t Ash_len = BS_K * Ash_stride; +const uint32_t Bsh_len = BS_CRS * Bsh_stride; + +shared SHMEM_TYPE Ash[Ash_len]; // K x CRS +shared SHMEM_TYPE Bsh[Bsh_len]; // CRS x NPQ + +// Threadtile sizes +const uint32_t TS_NPQ = BS_K * BS_NPQ / WG_SIZE / TS_K; + +// Number of threadtiles per blocktile +const uint32_t NT_K = BS_K / TS_K; +const uint32_t NT_NPQ = BS_NPQ / TS_NPQ; + +/* +Compute +KxCRS @ CRSxNPQ = K x NPQ +K=Cout +C=Cin +R,S=KH,KW +P,Q=OH,OW +*/ + +uint32_t B_idx_K = gl_WorkGroupID.x; +uint32_t B_idx_NPQ = gl_WorkGroupID.y; + +uint32_t T_y = tid / NT_NPQ; +uint32_t T_x = tid % NT_NPQ; + +uint32_t Ar = tid / BS_CRS; +uint32_t Ac = tid % BS_CRS; +const uint32_t ArpWg = WG_SIZE / BS_CRS; + +uint32_t Br = tid / BS_NPQ; +uint32_t Bc = tid % BS_NPQ; +const uint32_t BrpWg = WG_SIZE / BS_NPQ; + +// see init_fastdiv_values in ggml-vulkan.cpp +uint fastdiv(uint n, uint mp, uint L) { + uint msbs, lsbs; + // msbs = mulhi(n, mp) + umulExtended(n, mp, msbs, lsbs); + return (msbs + n) >> L; +} + +#ifdef COOPMAT2 +#define ACC_TYPE float16_t + +ACC_TYPE perElemOpStore(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem) +{ + uint32_t K_idx = B_idx_K * BS_K + r; + uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + c; + uint32_t N_idx = fastdiv(NPQ_idx, p.OWOHmp, p.OWOHL); // divide by p.OH * p.OW; + uint32_t OH_idx = fastdiv(NPQ_idx - N_idx * p.OH * p.OW, p.OWmp, p.OWL); // divide by p.OW; + uint32_t OW_idx = NPQ_idx - N_idx * p.OH * p.OW - OH_idx * p.OW; + uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + K_idx * p.nb2 + N_idx * p.nb3; + if (K_idx < K && NPQ_idx < NPQ) { + dst_data[dst_idx] = D_TYPE(elem); + } + return elem; +} +#endif + +void main() { +#ifdef COOPMAT2 + coopmat matC; + matC = coopmat(0.0); +#else + float regC[TS_K][TS_NPQ]; + for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) { + for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) { + regC[T_ly][T_lx] = 0.0; + } + } +#endif + /* Advance block in CRS dim */ + for (uint32_t B_idx_CRS = 0; B_idx_CRS < NB_CRS; B_idx_CRS++) { + uint32_t CRS_idx_a; + uint32_t Cin_idx_a; + uint32_t KH_idx_a; + uint32_t KW_idx_a; + +#ifdef USE_COLLECTIVES + uint32_t cached_CRS_idx; + uint32_t cached_Cin_idx; + uint32_t cached_KH_idx; + uint32_t cached_KW_idx; + if (use_collectives == 1) { + cached_CRS_idx = B_idx_CRS * BS_CRS + gl_SubgroupInvocationID; + cached_Cin_idx = fastdiv(cached_CRS_idx, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH); + uint32_t cached_CRS_remainder = (cached_CRS_idx - cached_Cin_idx * p.KW * p.KH); + cached_KH_idx = fastdiv(cached_CRS_remainder, p.KWmp, p.KWL); // divide by p.KW; + 
cached_KW_idx = cached_CRS_remainder - cached_KH_idx * p.KW; + + CRS_idx_a = subgroupShuffle(cached_CRS_idx, Ac); + Cin_idx_a = subgroupShuffle(cached_Cin_idx, Ac); + KH_idx_a = subgroupShuffle(cached_KH_idx, Ac); + KW_idx_a = subgroupShuffle(cached_KW_idx, Ac); + } else { + CRS_idx_a = B_idx_CRS * BS_CRS + Ac; // Global CRS_idx_a (column index of A) + Cin_idx_a = fastdiv(CRS_idx_a, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH); + uint32_t CRS_remainder = CRS_idx_a - Cin_idx_a * p.KW * p.KH; + KH_idx_a = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW; + KW_idx_a = CRS_remainder - KH_idx_a * p.KW; + } +#else + CRS_idx_a = B_idx_CRS * BS_CRS + Ac; // Global CRS_idx_a (column index of A) + Cin_idx_a = fastdiv(CRS_idx_a, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH); / (p.KW * p.KH); + CRS_remainder = CRS_idx_a - Cin_idx_a * p.KW * p.KH; + KH_idx_a = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW; + KW_idx_a = CRS_remainder - KH_idx_a * p.KW; +#endif + + /* Load kernel to A_block: (BS_K x BS_CRS)*/ + for (uint32_t r_offset = 0; r_offset < BS_K; r_offset += ArpWg) { + uint32_t B_ly = r_offset + Ar; + uint32_t B_lx = Ac; + uint32_t K_idx = B_idx_K * BS_K + B_ly; /* Global K_idx (row index of A)*/ + float val; + if (K_idx >= K || CRS_idx_a >= CRS) { + val = 0.0; + } else { + uint32_t knl_idx = min(KW_idx_a + KH_idx_a * p.nb01 + K_idx * p.nb02 + Cin_idx_a * p.nb03, K * CRS - 1); + val = knl_data[knl_idx]; + } + Ash[B_ly * Ash_stride + B_lx] = SHMEM_TYPE(val); + } + /* Load input to B_block: (BS_CRS x BS_NPQ) */ + UNROLL for (uint32_t r_offset = 0; r_offset < BS_CRS; r_offset += BrpWg) { + uint32_t B_ly = r_offset + Br; /* Row index of B block */ + uint32_t B_lx = Bc; + uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + B_lx; /* Global NPQ index (column index of B) */ + uint32_t N_idx = fastdiv(NPQ_idx, p.OWOHmp, p.OWOHL); // divide by p.OH * p.OW; + uint32_t NPQ_remainder = NPQ_idx - N_idx * p.OH * p.OW; + uint32_t OH_idx = fastdiv(NPQ_remainder, p.OWmp, p.OWL); // divide by p.OW; + uint32_t OW_idx = NPQ_remainder - OH_idx * p.OW; + + uint32_t CRS_idx_b; + uint32_t Cin_idx_b; + uint32_t KH_idx_b; + uint32_t KW_idx_b; +#ifdef USE_COLLECTIVES + if (use_collectives == 1) { + CRS_idx_b = subgroupShuffle(cached_CRS_idx, r_offset + Br); + Cin_idx_b = subgroupShuffle(cached_Cin_idx, r_offset + Br); + KH_idx_b = subgroupShuffle(cached_KH_idx, r_offset + Br); + KW_idx_b = subgroupShuffle(cached_KW_idx, r_offset + Br); + } else { + CRS_idx_b = B_idx_CRS * BS_CRS + B_ly; /* Global CRS index (row index of B) */ + Cin_idx_b = fastdiv(CRS_idx_b, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH); + uint32_t CRS_remainder = CRS_idx_b - Cin_idx_b * p.KW * p.KH; + KH_idx_b = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW; + KW_idx_b = CRS_remainder - KH_idx_b * p.KW; + } +#else + CRS_idx_b = B_idx_CRS * BS_CRS + B_ly; /* Global CRS index (row index of B) */ + Cin_idx_b = fastdiv(CRS_idx_b, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH); + uint32_t CRS_remainder = CRS_idx_b - Cin_idx_b * p.KW * p.KH; + KH_idx_b = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW; + KW_idx_b = CRS_remainder - KH_idx_b * p.KW; +#endif + + uint32_t H_idx_x_s1 = OH_idx - KH_idx_b * p.d1 + p.p1; + uint32_t W_idx_x_s0 = OW_idx - KW_idx_b * p.d0 + p.p0; + uint32_t H_idx = fastdiv(H_idx_x_s1, p.s0mp, p.s0L); + uint32_t W_idx = fastdiv(W_idx_x_s0, p.s1mp, p.s1L); + float val; + if (CRS_idx_b >= CRS || NPQ_idx >= NPQ || + int32_t(H_idx) < 0 || H_idx >= p.H || int32_t(W_idx) < 0 || W_idx >= p.W || + (H_idx_x_s1 - H_idx 
* p.s1 != 0) || (W_idx_x_s0 - W_idx * p.s0 != 0)) { + val = 0.0; + } else { + uint32_t src_idx = + min(max(W_idx + H_idx * p.nb11 + Cin_idx_b * p.nb12 + N_idx * p.nb13, 0), p.Cin * p.N * p.W * p.H - 1); + val = src_data[src_idx]; + } + Bsh[B_ly * Bsh_stride + B_lx] = SHMEM_TYPE(val); + } + barrier(); +#ifdef COOPMAT2 + coopmat matA; + coopmat matB; + + coopMatLoad(matA, Ash, 0, Ash_stride, gl_CooperativeMatrixLayoutRowMajor); + coopMatLoad(matB, Bsh, 0, Bsh_stride, gl_CooperativeMatrixLayoutRowMajor); + matC = coopMatMulAdd(matA, matB, matC); +#else + if (T_y * TS_K < K) { + UNROLL for (uint32_t CRS_lidx = 0; CRS_lidx < BS_CRS; CRS_lidx++) { + float regA[TS_K]; + float regB[TS_NPQ]; + for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) { + regA[T_ly] = Ash[(T_y * TS_K + T_ly) * Ash_stride + CRS_lidx]; + } + for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) { + regB[T_lx] = Bsh[CRS_lidx * Bsh_stride + T_x * TS_NPQ + T_lx]; + } + for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) { + for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) { + regC[T_ly][T_lx] = fma(regA[T_ly], regB[T_lx], regC[T_ly][T_lx]); + } + } + } + } +#endif + barrier(); + } + /* Save C* */ +#ifdef COOPMAT2 + coopMatPerElementNV(matC, matC, perElemOpStore); +#else + if (T_y * TS_K < K) { + for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) { + for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) { + uint32_t K_idx = B_idx_K * BS_K + T_y * TS_K + T_ly; + uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + T_x * TS_NPQ + T_lx; + uint32_t N_idx = fastdiv(NPQ_idx, p.OWOHmp, p.OWOHL); // divide by p.OH * p.OW; + uint32_t OH_idx = fastdiv(NPQ_idx - N_idx * p.OH * p.OW, p.OWmp, p.OWL); // divide by p.OW; + uint32_t OW_idx = NPQ_idx - N_idx * p.OH * p.OW - OH_idx * p.OW; + uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + K_idx * p.nb2 + N_idx * p.nb3; + if (K_idx < K && NPQ_idx < NPQ) { + dst_data[dst_idx] = regC[T_ly][T_lx]; + } + } + } + } +#endif +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index e818166d1c2a2..30b8e8a776c2f 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -804,6 +804,17 @@ void process_shaders() { string_to_spv("conv2d_f16_f32", "conv2d_mm.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", "[[unroll]]"}, {"COOPMAT2", "1"}}, true, false, true); #endif + string_to_spv("conv_transpose_2d_f32_unroll", "conv_transpose_2d_mm.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", "[[unroll]]"}}); + string_to_spv("conv_transpose_2d_f16_f32_unroll", "conv_transpose_2d_mm.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", "[[unroll]]"}}); + + string_to_spv("conv_transpose_2d_f32", "conv_transpose_2d_mm.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", ""}}); + string_to_spv("conv_transpose_2d_f16_f32", "conv_transpose_2d_mm.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", ""}}); + +#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + string_to_spv("conv_transpose_2d_f32", "conv_transpose_2d_mm.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", "[[unroll]]"}, {"COOPMAT2", "1"}}, true, false, true); + string_to_spv("conv_transpose_2d_f16_f32", 
"conv_transpose_2d_mm.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", "[[unroll]]"}, {"COOPMAT2", "1"}}, true, false, true); +#endif + string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}})); string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}})); string_to_spv("conv2d_dw_whcn_f16_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}})); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index b54a1a4e823f9..4f70865221839 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -3969,6 +3969,10 @@ struct test_conv_transpose_2d : public test_case { return VARS_TO_STR3(ne_input, ne_kernel, stride); } + double max_nmse_err() override { + return 5e-4; // The default 1e-7 is too small for Vulkan. + } + test_conv_transpose_2d(std::array ne_input = {10, 10, 3, 1}, // [input_width, input_height, input_channels, 1] std::array ne_kernel = {3, 3, 3, 1}, // [kernel_width, kernel_height, input_channels, 1] int stride = 1) From 5c888ca29e119c20592e3580ec7e1fed3612f970 Mon Sep 17 00:00:00 2001 From: Shin-myoung-serp Date: Tue, 16 Sep 2025 16:39:51 +0900 Subject: [PATCH 2/7] Vulkan: fix typo in conv_transpose_2d shader(s0mp, s0L, s1mp, s1L) --- ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_2d_mm.comp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_2d_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_2d_mm.comp index 87e9142f39534..a1cca7f36b9c5 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_2d_mm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_2d_mm.comp @@ -273,8 +273,8 @@ void main() { uint32_t H_idx_x_s1 = OH_idx - KH_idx_b * p.d1 + p.p1; uint32_t W_idx_x_s0 = OW_idx - KW_idx_b * p.d0 + p.p0; - uint32_t H_idx = fastdiv(H_idx_x_s1, p.s0mp, p.s0L); - uint32_t W_idx = fastdiv(W_idx_x_s0, p.s1mp, p.s1L); + uint32_t H_idx = fastdiv(H_idx_x_s1, p.s1mp, p.s1L); + uint32_t W_idx = fastdiv(W_idx_x_s0, p.s0mp, p.s0L); float val; if (CRS_idx_b >= CRS || NPQ_idx >= NPQ || int32_t(H_idx) < 0 || H_idx >= p.H || int32_t(W_idx) < 0 || W_idx >= p.W || From e029cda87f1c304008d46f189ff03f2444211723 Mon Sep 17 00:00:00 2001 From: Shin-myoung-serp Date: Tue, 16 Sep 2025 16:55:00 +0900 Subject: [PATCH 3/7] Vulkan: fix incorrect indentation in conv_transpose_2d shader --- ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_2d_mm.comp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_2d_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_2d_mm.comp index a1cca7f36b9c5..0efd47c78575b 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_2d_mm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_2d_mm.comp @@ -227,7 +227,7 @@ void main() { uint32_t B_ly = r_offset + Ar; uint32_t B_lx = Ac; uint32_t K_idx = B_idx_K * BS_K + B_ly; /* Global K_idx (row index of A)*/ - float val; + float val; if (K_idx >= K || CRS_idx_a >= CRS) { val = 0.0; } else { From f5ae689945fc3d3b0d15c01dc023487c936ad51d Mon Sep 17 00:00:00 2001 From: Shin-myoung-serp Date: Wed, 17 Sep 2025 12:51:38 +0900 Subject: [PATCH 4/7] Vulkan: add checking the push constants size limit 
and reuse conv2d_mm.comp for conv_transpose_2d operation --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 26 +- .../ggml-vulkan/vulkan-shaders/conv2d_mm.comp | 38 +- .../vulkan-shaders/conv_transpose_2d_mm.comp | 339 ------------------ .../vulkan-shaders/vulkan-shaders-gen.cpp | 37 +- 4 files changed, 63 insertions(+), 377 deletions(-) delete mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_2d_mm.comp diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 218861475ecd6..c409f061edd33 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -3603,26 +3603,25 @@ static void ggml_vk_load_shaders(vk_device& device) { device, device->pipeline_##name##type_suffix[s], #name #type_suffix, \ name##type_suffix##spv_suffix##_len, name##type_suffix##spv_suffix##_data, "main", 3, \ sizeof(vk_op_##name##_push_constants), wg_denoms, spec_constants, 1, true, use_collectives); +#define CREATE_CONVS(spv_suffix) \ + CREATE_CONV(conv2d, _f32, spv_suffix) \ + CREATE_CONV(conv2d, _f16_f32, spv_suffix) \ + if (device->properties.limits.maxPushConstantsSize >= sizeof(vk_op_conv_transpose_2d_push_constants)) { \ + CREATE_CONV(conv_transpose_2d, _f32, spv_suffix) \ + CREATE_CONV(conv_transpose_2d, _f16_f32, spv_suffix) \ + } #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) if (device->coopmat2) { - CREATE_CONV(conv2d, _f32, _cm2) - CREATE_CONV(conv2d, _f16_f32, _cm2) - CREATE_CONV(conv_transpose_2d, _f32, _cm2) - CREATE_CONV(conv_transpose_2d, _f16_f32, _cm2) + CREATE_CONVS(_cm2) } else #endif if (conv2d_UNROLL) { - CREATE_CONV(conv2d, _f32, _unroll) - CREATE_CONV(conv2d, _f16_f32, _unroll) - CREATE_CONV(conv_transpose_2d, _f32, _unroll) - CREATE_CONV(conv_transpose_2d, _f16_f32, _unroll) + CREATE_CONVS(_unroll) } else { - CREATE_CONV(conv2d, _f32, ) - CREATE_CONV(conv2d, _f16_f32, ) - CREATE_CONV(conv_transpose_2d, _f32, ) - CREATE_CONV(conv_transpose_2d, _f16_f32, ) + CREATE_CONVS( ) } #undef CREATE_CONV +#undef CREATE_CONVS } ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1); @@ -12722,6 +12721,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm // Op is disabled for Apple because it segfaults at pipeline create time on MoltenVK ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; const vk_device& device = ggml_vk_get_device(ctx->device); + if (op->op == GGML_OP_CONV_TRANSPOSE_2D && !device->pipeline_conv_transpose_2d_f32[0]) { + return false; + } // Channel-contiguous format is not supported yet. 
return ((op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && op->src[1]->type == GGML_TYPE_F32 && diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp index 86bafba4a4398..23ea3d0794198 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp @@ -16,7 +16,7 @@ // shape notation: [dim(N), ..., dim(0)] -- stride(dim(j)) >= stride(dim(i)) if i > j layout(binding = 0) readonly buffer A { A_TYPE knl_data[]; -}; // src0 - kernel: [KW, KH, Cin, Cout] +}; // src0 - kernel: [KW, KH, Cin, Cout] for conv_2d, [KW, KH, Cout, Cin] for conv_transposed_2d layout(binding = 1) readonly buffer B { B_TYPE src_data[]; @@ -66,6 +66,10 @@ layout(push_constant) uniform parameter { uint32_t KWKHmp; uint32_t KWKHL; uint32_t OWmp; uint32_t OWL; uint32_t OWOHmp; uint32_t OWOHL; +#ifdef TRANSPOSE + uint32_t s0mp; uint32_t s0L; + uint32_t s1mp; uint32_t s1L; +#endif } p; @@ -225,10 +229,16 @@ void main() { uint32_t B_ly = r_offset + Ar; uint32_t B_lx = Ac; uint32_t K_idx = B_idx_K * BS_K + B_ly; /* Global K_idx (row index of A)*/ - uint32_t knl_idx = min(KW_idx_a + KH_idx_a * p.nb01 + Cin_idx_a * p.nb02 + K_idx * p.nb03, K * CRS - 1); - float val = knl_data[knl_idx]; + float val; if (K_idx >= K || CRS_idx_a >= CRS) { val = 0.0; + } else { +#ifdef TRANSPOSE + uint32_t knl_idx = min(KW_idx_a + KH_idx_a * p.nb01 + K_idx * p.nb02 + Cin_idx_a * p.nb03, K * CRS - 1); +#else + uint32_t knl_idx = min(KW_idx_a + KH_idx_a * p.nb01 + Cin_idx_a * p.nb02 + K_idx * p.nb03, K * CRS - 1); +#endif + val = knl_data[knl_idx]; } Ash[B_ly * Ash_stride + B_lx] = SHMEM_TYPE(val); } @@ -267,13 +277,27 @@ void main() { KW_idx_b = CRS_remainder - KH_idx_b * p.KW; #endif +#ifdef TRANSPOSE + uint32_t H_idx_x_s1 = OH_idx - KH_idx_b * p.d1 + p.p1; + uint32_t W_idx_x_s0 = OW_idx - KW_idx_b * p.d0 + p.p0; + uint32_t H_idx = fastdiv(H_idx_x_s1, p.s1mp, p.s1L); + uint32_t W_idx = fastdiv(W_idx_x_s0, p.s0mp, p.s0L); +#else uint32_t H_idx = OH_idx * p.s1 + KH_idx_b * p.d1 - p.p1; uint32_t W_idx = OW_idx * p.s0 + KW_idx_b * p.d0 - p.p0; - uint32_t src_idx = - min(max(W_idx + H_idx * p.nb11 + Cin_idx_b * p.nb12 + N_idx * p.nb13, 0), p.Cin * p.N * p.W * p.H - 1); - float val = src_data[src_idx]; - if (CRS_idx_b >= CRS || NPQ_idx >= NPQ || H_idx < 0 || H_idx >= p.H || W_idx < 0 || W_idx >= p.W) { +#endif + float val; + if (CRS_idx_b >= CRS || NPQ_idx >= NPQ + || int32_t(H_idx) < 0 || H_idx >= p.H || int32_t(W_idx) < 0 || W_idx >= p.W +#ifdef TRANSPOSE + || (H_idx_x_s1 - H_idx * p.s1 != 0) || (W_idx_x_s0 - W_idx * p.s0 != 0) +#endif + ) { val = 0.0; + } else { + uint32_t src_idx = + min(max(W_idx + H_idx * p.nb11 + Cin_idx_b * p.nb12 + N_idx * p.nb13, 0), p.Cin * p.N * p.W * p.H - 1); + val = src_data[src_idx]; } Bsh[B_ly * Bsh_stride + B_lx] = SHMEM_TYPE(val); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_2d_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_2d_mm.comp deleted file mode 100644 index 0efd47c78575b..0000000000000 --- a/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_2d_mm.comp +++ /dev/null @@ -1,339 +0,0 @@ -#version 450 - -#extension GL_EXT_control_flow_attributes : enable -#ifdef COOPMAT2 -#extension GL_NV_cooperative_matrix2 : enable -#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require -#extension GL_KHR_memory_scope_semantics : enable -#endif - -#ifdef USE_COLLECTIVES -# extension GL_KHR_shader_subgroup_shuffle : enable -#endif - -#include 
"types.comp" - -// shape notation: [dim(N), ..., dim(0)] -- stride(dim(j)) >= stride(dim(i)) if i > j -layout(binding = 0) readonly buffer A { - A_TYPE knl_data[]; -}; // src0 - kernel: [KW, KH, Cout, Cin] - -layout(binding = 1) readonly buffer B { - B_TYPE src_data[]; -}; // src1 - input: [W, H, Cin, N] -- channel_first format - -layout(binding = 2) writeonly buffer D { - D_TYPE dst_data[]; -}; // dst - result: [OW, OH, Cout, N] - -layout(push_constant) uniform parameter { - // I/O channels, batch size - uint32_t Cout; - uint32_t Cin; - uint32_t N; - - // Tensor spatial sizes: kernel, input, output - uint32_t KW; - uint32_t KH; - uint32_t W; - uint32_t H; - uint32_t OW; - uint32_t OH; - - // Parameters: stride, padding, dilation - 0=y, 1=x - uint32_t s0; - uint32_t s1; - uint32_t p0; - uint32_t p1; - uint32_t d0; - uint32_t d1; - - // Strides in elements - uint32_t nb01; - uint32_t nb02; - uint32_t nb03; - - uint32_t nb11; - uint32_t nb12; - uint32_t nb13; - - uint32_t nb1; - uint32_t nb2; - uint32_t nb3; - - // fastdiv helper values - uint32_t KWmp; uint32_t KWL; - uint32_t KWKHmp; uint32_t KWKHL; - uint32_t OWmp; uint32_t OWL; - uint32_t OWOHmp; uint32_t OWOHL; - uint32_t s0mp; uint32_t s0L; - uint32_t s1mp; uint32_t s1L; -} - -p; - -layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; -// Blocktile sizes -layout(constant_id = 1) const uint BS_K = 128; -layout(constant_id = 2) const uint BS_CRS = 16; -layout(constant_id = 3) const uint BS_NPQ = 128; -// Thread-tile sizes -layout(constant_id = 4) const uint TS_K = 8; -layout(constant_id = 5) const uint use_collectives = 1; -layout(constant_id = 6) const uint SHMEM_PAD = 4; - -uint32_t tid = gl_LocalInvocationID.x; -const uint32_t WG_SIZE = gl_WorkGroupSize.x; - -uint splitWork(uint work_size, uint block_size) { - return (block_size + work_size - 1) / block_size; -} - -uint32_t K = p.Cout; -uint32_t CRS = p.Cin * p.KH * p.KW; -uint32_t NPQ = p.N * p.OH * p.OW; - -uint32_t n_elems_out = K * NPQ; - -// Number of blocktiles per input -uint32_t NB_CRS = splitWork(CRS, BS_CRS); - -#ifdef COOPMAT2 -#define SHMEM_TYPE float16_t -#else -#define SHMEM_TYPE float -#endif - -const uint32_t Ash_stride = BS_CRS + SHMEM_PAD; -const uint32_t Bsh_stride = BS_NPQ + SHMEM_PAD; - -const uint32_t Ash_numel = BS_K * BS_CRS; -const uint32_t Bsh_numel = BS_CRS * BS_NPQ; - -const uint32_t Ash_len = BS_K * Ash_stride; -const uint32_t Bsh_len = BS_CRS * Bsh_stride; - -shared SHMEM_TYPE Ash[Ash_len]; // K x CRS -shared SHMEM_TYPE Bsh[Bsh_len]; // CRS x NPQ - -// Threadtile sizes -const uint32_t TS_NPQ = BS_K * BS_NPQ / WG_SIZE / TS_K; - -// Number of threadtiles per blocktile -const uint32_t NT_K = BS_K / TS_K; -const uint32_t NT_NPQ = BS_NPQ / TS_NPQ; - -/* -Compute -KxCRS @ CRSxNPQ = K x NPQ -K=Cout -C=Cin -R,S=KH,KW -P,Q=OH,OW -*/ - -uint32_t B_idx_K = gl_WorkGroupID.x; -uint32_t B_idx_NPQ = gl_WorkGroupID.y; - -uint32_t T_y = tid / NT_NPQ; -uint32_t T_x = tid % NT_NPQ; - -uint32_t Ar = tid / BS_CRS; -uint32_t Ac = tid % BS_CRS; -const uint32_t ArpWg = WG_SIZE / BS_CRS; - -uint32_t Br = tid / BS_NPQ; -uint32_t Bc = tid % BS_NPQ; -const uint32_t BrpWg = WG_SIZE / BS_NPQ; - -// see init_fastdiv_values in ggml-vulkan.cpp -uint fastdiv(uint n, uint mp, uint L) { - uint msbs, lsbs; - // msbs = mulhi(n, mp) - umulExtended(n, mp, msbs, lsbs); - return (msbs + n) >> L; -} - -#ifdef COOPMAT2 -#define ACC_TYPE float16_t - -ACC_TYPE perElemOpStore(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem) -{ - uint32_t K_idx = B_idx_K * BS_K + r; - 
uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + c; - uint32_t N_idx = fastdiv(NPQ_idx, p.OWOHmp, p.OWOHL); // divide by p.OH * p.OW; - uint32_t OH_idx = fastdiv(NPQ_idx - N_idx * p.OH * p.OW, p.OWmp, p.OWL); // divide by p.OW; - uint32_t OW_idx = NPQ_idx - N_idx * p.OH * p.OW - OH_idx * p.OW; - uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + K_idx * p.nb2 + N_idx * p.nb3; - if (K_idx < K && NPQ_idx < NPQ) { - dst_data[dst_idx] = D_TYPE(elem); - } - return elem; -} -#endif - -void main() { -#ifdef COOPMAT2 - coopmat matC; - matC = coopmat(0.0); -#else - float regC[TS_K][TS_NPQ]; - for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) { - for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) { - regC[T_ly][T_lx] = 0.0; - } - } -#endif - /* Advance block in CRS dim */ - for (uint32_t B_idx_CRS = 0; B_idx_CRS < NB_CRS; B_idx_CRS++) { - uint32_t CRS_idx_a; - uint32_t Cin_idx_a; - uint32_t KH_idx_a; - uint32_t KW_idx_a; - -#ifdef USE_COLLECTIVES - uint32_t cached_CRS_idx; - uint32_t cached_Cin_idx; - uint32_t cached_KH_idx; - uint32_t cached_KW_idx; - if (use_collectives == 1) { - cached_CRS_idx = B_idx_CRS * BS_CRS + gl_SubgroupInvocationID; - cached_Cin_idx = fastdiv(cached_CRS_idx, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH); - uint32_t cached_CRS_remainder = (cached_CRS_idx - cached_Cin_idx * p.KW * p.KH); - cached_KH_idx = fastdiv(cached_CRS_remainder, p.KWmp, p.KWL); // divide by p.KW; - cached_KW_idx = cached_CRS_remainder - cached_KH_idx * p.KW; - - CRS_idx_a = subgroupShuffle(cached_CRS_idx, Ac); - Cin_idx_a = subgroupShuffle(cached_Cin_idx, Ac); - KH_idx_a = subgroupShuffle(cached_KH_idx, Ac); - KW_idx_a = subgroupShuffle(cached_KW_idx, Ac); - } else { - CRS_idx_a = B_idx_CRS * BS_CRS + Ac; // Global CRS_idx_a (column index of A) - Cin_idx_a = fastdiv(CRS_idx_a, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH); - uint32_t CRS_remainder = CRS_idx_a - Cin_idx_a * p.KW * p.KH; - KH_idx_a = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW; - KW_idx_a = CRS_remainder - KH_idx_a * p.KW; - } -#else - CRS_idx_a = B_idx_CRS * BS_CRS + Ac; // Global CRS_idx_a (column index of A) - Cin_idx_a = fastdiv(CRS_idx_a, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH); / (p.KW * p.KH); - CRS_remainder = CRS_idx_a - Cin_idx_a * p.KW * p.KH; - KH_idx_a = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW; - KW_idx_a = CRS_remainder - KH_idx_a * p.KW; -#endif - - /* Load kernel to A_block: (BS_K x BS_CRS)*/ - for (uint32_t r_offset = 0; r_offset < BS_K; r_offset += ArpWg) { - uint32_t B_ly = r_offset + Ar; - uint32_t B_lx = Ac; - uint32_t K_idx = B_idx_K * BS_K + B_ly; /* Global K_idx (row index of A)*/ - float val; - if (K_idx >= K || CRS_idx_a >= CRS) { - val = 0.0; - } else { - uint32_t knl_idx = min(KW_idx_a + KH_idx_a * p.nb01 + K_idx * p.nb02 + Cin_idx_a * p.nb03, K * CRS - 1); - val = knl_data[knl_idx]; - } - Ash[B_ly * Ash_stride + B_lx] = SHMEM_TYPE(val); - } - /* Load input to B_block: (BS_CRS x BS_NPQ) */ - UNROLL for (uint32_t r_offset = 0; r_offset < BS_CRS; r_offset += BrpWg) { - uint32_t B_ly = r_offset + Br; /* Row index of B block */ - uint32_t B_lx = Bc; - uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + B_lx; /* Global NPQ index (column index of B) */ - uint32_t N_idx = fastdiv(NPQ_idx, p.OWOHmp, p.OWOHL); // divide by p.OH * p.OW; - uint32_t NPQ_remainder = NPQ_idx - N_idx * p.OH * p.OW; - uint32_t OH_idx = fastdiv(NPQ_remainder, p.OWmp, p.OWL); // divide by p.OW; - uint32_t OW_idx = NPQ_remainder - OH_idx * p.OW; - - uint32_t CRS_idx_b; - uint32_t Cin_idx_b; - uint32_t KH_idx_b; - uint32_t KW_idx_b; -#ifdef 
USE_COLLECTIVES - if (use_collectives == 1) { - CRS_idx_b = subgroupShuffle(cached_CRS_idx, r_offset + Br); - Cin_idx_b = subgroupShuffle(cached_Cin_idx, r_offset + Br); - KH_idx_b = subgroupShuffle(cached_KH_idx, r_offset + Br); - KW_idx_b = subgroupShuffle(cached_KW_idx, r_offset + Br); - } else { - CRS_idx_b = B_idx_CRS * BS_CRS + B_ly; /* Global CRS index (row index of B) */ - Cin_idx_b = fastdiv(CRS_idx_b, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH); - uint32_t CRS_remainder = CRS_idx_b - Cin_idx_b * p.KW * p.KH; - KH_idx_b = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW; - KW_idx_b = CRS_remainder - KH_idx_b * p.KW; - } -#else - CRS_idx_b = B_idx_CRS * BS_CRS + B_ly; /* Global CRS index (row index of B) */ - Cin_idx_b = fastdiv(CRS_idx_b, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH); - uint32_t CRS_remainder = CRS_idx_b - Cin_idx_b * p.KW * p.KH; - KH_idx_b = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW; - KW_idx_b = CRS_remainder - KH_idx_b * p.KW; -#endif - - uint32_t H_idx_x_s1 = OH_idx - KH_idx_b * p.d1 + p.p1; - uint32_t W_idx_x_s0 = OW_idx - KW_idx_b * p.d0 + p.p0; - uint32_t H_idx = fastdiv(H_idx_x_s1, p.s1mp, p.s1L); - uint32_t W_idx = fastdiv(W_idx_x_s0, p.s0mp, p.s0L); - float val; - if (CRS_idx_b >= CRS || NPQ_idx >= NPQ || - int32_t(H_idx) < 0 || H_idx >= p.H || int32_t(W_idx) < 0 || W_idx >= p.W || - (H_idx_x_s1 - H_idx * p.s1 != 0) || (W_idx_x_s0 - W_idx * p.s0 != 0)) { - val = 0.0; - } else { - uint32_t src_idx = - min(max(W_idx + H_idx * p.nb11 + Cin_idx_b * p.nb12 + N_idx * p.nb13, 0), p.Cin * p.N * p.W * p.H - 1); - val = src_data[src_idx]; - } - Bsh[B_ly * Bsh_stride + B_lx] = SHMEM_TYPE(val); - } - barrier(); -#ifdef COOPMAT2 - coopmat matA; - coopmat matB; - - coopMatLoad(matA, Ash, 0, Ash_stride, gl_CooperativeMatrixLayoutRowMajor); - coopMatLoad(matB, Bsh, 0, Bsh_stride, gl_CooperativeMatrixLayoutRowMajor); - matC = coopMatMulAdd(matA, matB, matC); -#else - if (T_y * TS_K < K) { - UNROLL for (uint32_t CRS_lidx = 0; CRS_lidx < BS_CRS; CRS_lidx++) { - float regA[TS_K]; - float regB[TS_NPQ]; - for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) { - regA[T_ly] = Ash[(T_y * TS_K + T_ly) * Ash_stride + CRS_lidx]; - } - for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) { - regB[T_lx] = Bsh[CRS_lidx * Bsh_stride + T_x * TS_NPQ + T_lx]; - } - for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) { - for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) { - regC[T_ly][T_lx] = fma(regA[T_ly], regB[T_lx], regC[T_ly][T_lx]); - } - } - } - } -#endif - barrier(); - } - /* Save C* */ -#ifdef COOPMAT2 - coopMatPerElementNV(matC, matC, perElemOpStore); -#else - if (T_y * TS_K < K) { - for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) { - for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) { - uint32_t K_idx = B_idx_K * BS_K + T_y * TS_K + T_ly; - uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + T_x * TS_NPQ + T_lx; - uint32_t N_idx = fastdiv(NPQ_idx, p.OWOHmp, p.OWOHL); // divide by p.OH * p.OW; - uint32_t OH_idx = fastdiv(NPQ_idx - N_idx * p.OH * p.OW, p.OWmp, p.OWL); // divide by p.OW; - uint32_t OW_idx = NPQ_idx - N_idx * p.OH * p.OW - OH_idx * p.OW; - uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + K_idx * p.nb2 + N_idx * p.nb3; - if (K_idx < K && NPQ_idx < NPQ) { - dst_data[dst_idx] = regC[T_ly][T_lx]; - } - } - } - } -#endif -} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 30b8e8a776c2f..f9c5ef52b8b33 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ 
b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -793,27 +793,26 @@ void process_shaders() { string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); string_to_spv("opt_step_sgd_f32", "opt_step_sgd.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); - string_to_spv("conv2d_f32_unroll", "conv2d_mm.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", "[[unroll]]"}}); - string_to_spv("conv2d_f16_f32_unroll", "conv2d_mm.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", "[[unroll]]"}}); - - string_to_spv("conv2d_f32", "conv2d_mm.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", ""}}); - string_to_spv("conv2d_f16_f32", "conv2d_mm.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", ""}}); - + for (auto transpose : {false, true}) { + for (auto unroll : {false, true}) { + for (auto a_f16 : {false, true}) { + std::map defines = { + {"A_TYPE", a_f16 ? "float16_t" : "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, + {"USE_COLLECTIVES", "1"}, {"UNROLL", unroll ? "[[unroll]]" : ""}, + }; + if (transpose) defines["TRANSPOSE"] = "1"; + std::string name = std::string(transpose ? "conv_transpose_2d": "conv2d") + + (a_f16 ? "_f16" : "") + "_f32"; + string_to_spv(name + (unroll ? "_unroll" : ""), "conv2d_mm.comp", defines); #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) - string_to_spv("conv2d_f32", "conv2d_mm.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", "[[unroll]]"}, {"COOPMAT2", "1"}}, true, false, true); - string_to_spv("conv2d_f16_f32", "conv2d_mm.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", "[[unroll]]"}, {"COOPMAT2", "1"}}, true, false, true); -#endif - - string_to_spv("conv_transpose_2d_f32_unroll", "conv_transpose_2d_mm.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", "[[unroll]]"}}); - string_to_spv("conv_transpose_2d_f16_f32_unroll", "conv_transpose_2d_mm.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", "[[unroll]]"}}); - - string_to_spv("conv_transpose_2d_f32", "conv_transpose_2d_mm.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", ""}}); - string_to_spv("conv_transpose_2d_f16_f32", "conv_transpose_2d_mm.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", ""}}); - -#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) - string_to_spv("conv_transpose_2d_f32", "conv_transpose_2d_mm.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", "[[unroll]]"}, {"COOPMAT2", "1"}}, true, false, true); - string_to_spv("conv_transpose_2d_f16_f32", "conv_transpose_2d_mm.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", "[[unroll]]"}, {"COOPMAT2", "1"}}, true, false, true); + if (unroll) { + defines["COOPMAT2"] = "1"; + string_to_spv(name, "conv2d_mm.comp", defines, true, false, true); + } #endif + } + } + } string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, 
{"WHCN", "1"}})); string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}})); From 12aaeaea1885e2b1fd2ee9f9f61467a73e10632f Mon Sep 17 00:00:00 2001 From: Shin-myoung-serp Date: Fri, 19 Sep 2025 10:56:39 +0900 Subject: [PATCH 5/7] Vulkan: revert the order of the index calculation and bound check in conv_2d shader --- .../ggml-vulkan/vulkan-shaders/conv2d_mm.comp | 20 ++++++++----------- 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp index 23ea3d0794198..75caf16c8d1b1 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp @@ -229,16 +229,14 @@ void main() { uint32_t B_ly = r_offset + Ar; uint32_t B_lx = Ac; uint32_t K_idx = B_idx_K * BS_K + B_ly; /* Global K_idx (row index of A)*/ - float val; - if (K_idx >= K || CRS_idx_a >= CRS) { - val = 0.0; - } else { #ifdef TRANSPOSE - uint32_t knl_idx = min(KW_idx_a + KH_idx_a * p.nb01 + K_idx * p.nb02 + Cin_idx_a * p.nb03, K * CRS - 1); + uint32_t knl_idx = min(KW_idx_a + KH_idx_a * p.nb01 + K_idx * p.nb02 + Cin_idx_a * p.nb03, K * CRS - 1); #else - uint32_t knl_idx = min(KW_idx_a + KH_idx_a * p.nb01 + Cin_idx_a * p.nb02 + K_idx * p.nb03, K * CRS - 1); + uint32_t knl_idx = min(KW_idx_a + KH_idx_a * p.nb01 + Cin_idx_a * p.nb02 + K_idx * p.nb03, K * CRS - 1); #endif - val = knl_data[knl_idx]; + float val = knl_data[knl_idx]; + if (K_idx >= K || CRS_idx_a >= CRS) { + val = 0.0; } Ash[B_ly * Ash_stride + B_lx] = SHMEM_TYPE(val); } @@ -286,7 +284,9 @@ void main() { uint32_t H_idx = OH_idx * p.s1 + KH_idx_b * p.d1 - p.p1; uint32_t W_idx = OW_idx * p.s0 + KW_idx_b * p.d0 - p.p0; #endif - float val; + uint32_t src_idx = + min(max(W_idx + H_idx * p.nb11 + Cin_idx_b * p.nb12 + N_idx * p.nb13, 0), p.Cin * p.N * p.W * p.H - 1); + float val = src_data[src_idx]; if (CRS_idx_b >= CRS || NPQ_idx >= NPQ || int32_t(H_idx) < 0 || H_idx >= p.H || int32_t(W_idx) < 0 || W_idx >= p.W #ifdef TRANSPOSE @@ -294,10 +294,6 @@ void main() { #endif ) { val = 0.0; - } else { - uint32_t src_idx = - min(max(W_idx + H_idx * p.nb11 + Cin_idx_b * p.nb12 + N_idx * p.nb13, 0), p.Cin * p.N * p.W * p.H - 1); - val = src_data[src_idx]; } Bsh[B_ly * Bsh_stride + B_lx] = SHMEM_TYPE(val); } From 7e24d171f63f6e8af8f861ef608d41e09a50f875 Mon Sep 17 00:00:00 2001 From: Shin-myoung-serp Date: Fri, 19 Sep 2025 11:20:12 +0900 Subject: [PATCH 6/7] Vulkan: explicity check push constants limit in supports_op() for conv_transpose_2d operation. 
--- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index c409f061edd33..c269033e5373b 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -12721,7 +12721,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm // Op is disabled for Apple because it segfaults at pipeline create time on MoltenVK ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; const vk_device& device = ggml_vk_get_device(ctx->device); - if (op->op == GGML_OP_CONV_TRANSPOSE_2D && !device->pipeline_conv_transpose_2d_f32[0]) { + if (op->op == GGML_OP_CONV_TRANSPOSE_2D && + device->properties.limits.maxPushConstantsSize < sizeof(vk_op_conv_transpose_2d_push_constants)) { return false; } // Channel-contiguous format is not supported yet. From 55b3fb5d636ce550d35e37082f89905844542488 Mon Sep 17 00:00:00 2001 From: Shin-myoung-serp Date: Sun, 21 Sep 2025 21:38:07 +0900 Subject: [PATCH 7/7] Vulkan: remove unnecessary lower bound checks for H/W_idx in the conv_2d shader. --- ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp index 75caf16c8d1b1..44a64ddc80f62 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp @@ -288,7 +288,7 @@ void main() { min(max(W_idx + H_idx * p.nb11 + Cin_idx_b * p.nb12 + N_idx * p.nb13, 0), p.Cin * p.N * p.W * p.H - 1); float val = src_data[src_idx]; if (CRS_idx_b >= CRS || NPQ_idx >= NPQ - || int32_t(H_idx) < 0 || H_idx >= p.H || int32_t(W_idx) < 0 || W_idx >= p.W + || H_idx >= p.H || W_idx >= p.W // Lower bound checks aren't necessary. (idx >= 0x80000000 for such case) #ifdef TRANSPOSE || (H_idx_x_s1 - H_idx * p.s1 != 0) || (W_idx_x_s0 - W_idx * p.s0 != 0) #endif
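To make the TRANSPOSE gather path above easier to follow: instead of scattering each input sample into a strided window of the output, the shader walks output positions and pulls in only the input samples whose stride-aligned position lands exactly on that output element, which is what the `(H_idx_x_s1 - H_idx * p.s1 != 0)` divisibility check enforces. The scalar reference below is a minimal sketch of that same mapping for the p0 = p1 = 0, d0 = d1 = 1 case driven by ggml_conv_transpose_2d_p0; the function name, the flattened indexing, and the tiny example in main() are illustrative only and are not part of the patch.

#include <cstddef>
#include <cstdio>
#include <vector>

// Scalar gather-formulation sketch of conv_transpose_2d (stride s, padding 0, dilation 1).
// Layouts follow the shader comments: kernel [KW, KH, Cout, Cin], input [W, H, Cin, N],
// output [OW, OH, Cout, N], with OW = (W - 1) * s + KW and OH = (H - 1) * s + KH.
static std::vector<float> conv_transpose_2d_ref(
        const std::vector<float> & knl, const std::vector<float> & src,
        int Cin, int Cout, int KW, int KH, int N, int W, int H, int s) {
    const int OW = (W - 1) * s + KW;
    const int OH = (H - 1) * s + KH;
    std::vector<float> dst(std::size_t(N) * Cout * OH * OW, 0.0f);
    for (int n = 0; n < N; ++n)
    for (int co = 0; co < Cout; ++co)
    for (int oh = 0; oh < OH; ++oh)
    for (int ow = 0; ow < OW; ++ow) {
        float acc = 0.0f;
        for (int ci = 0; ci < Cin; ++ci)
        for (int kh = 0; kh < KH; ++kh)
        for (int kw = 0; kw < KW; ++kw) {
            // An input sample contributes only if (oh - kh) and (ow - kw) are
            // non-negative multiples of the stride -- the same condition the
            // shader tests after its fastdiv by s0/s1.
            const int hs = oh - kh, ws = ow - kw;
            if (hs < 0 || ws < 0 || hs % s != 0 || ws % s != 0) continue;
            const int h = hs / s, w = ws / s;
            if (h >= H || w >= W) continue;
            acc += knl[((std::size_t(ci) * Cout + co) * KH + kh) * KW + kw]
                 * src[((std::size_t(n)  * Cin  + ci) * H  + h ) * W  + w ];
        }
        dst[((std::size_t(n) * Cout + co) * OH + oh) * OW + ow] = acc;
    }
    return dst;
}

int main() {
    // 2x2 input, 3x3 all-ones kernel, stride 2 -> 5x5 output (single channel, N = 1)
    std::vector<float> knl(9, 1.0f), src = { 1, 2, 3, 4 };
    std::vector<float> dst = conv_transpose_2d_ref(knl, src, 1, 1, 3, 3, 1, 2, 2, 2);
    std::printf("dst[0] = %.1f, dst[2] = %.1f\n", dst[0], dst[2]); // expect 1.0 and 3.0
    return 0;
}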