@@ -3603,26 +3603,25 @@ static void ggml_vk_load_shaders(vk_device& device) {
36033603 device, device->pipeline_##name##type_suffix[s], #name #type_suffix, \
36043604 name##type_suffix##spv_suffix##_len, name##type_suffix##spv_suffix##_data, "main", 3, \
36053605 sizeof(vk_op_##name##_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);
3606+ #define CREATE_CONVS(spv_suffix) \
3607+ CREATE_CONV(conv2d, _f32, spv_suffix) \
3608+ CREATE_CONV(conv2d, _f16_f32, spv_suffix) \
3609+ if (device->properties.limits.maxPushConstantsSize >= sizeof(vk_op_conv_transpose_2d_push_constants)) { \
3610+ CREATE_CONV(conv_transpose_2d, _f32, spv_suffix) \
3611+ CREATE_CONV(conv_transpose_2d, _f16_f32, spv_suffix) \
3612+ }
36063613#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
36073614 if (device->coopmat2) {
3608- CREATE_CONV(conv2d, _f32, _cm2)
3609- CREATE_CONV(conv2d, _f16_f32, _cm2)
3610- CREATE_CONV(conv_transpose_2d, _f32, _cm2)
3611- CREATE_CONV(conv_transpose_2d, _f16_f32, _cm2)
3615+ CREATE_CONVS(_cm2)
36123616 } else
36133617#endif
36143618 if (conv2d_UNROLL) {
3615- CREATE_CONV(conv2d, _f32, _unroll)
3616- CREATE_CONV(conv2d, _f16_f32, _unroll)
3617- CREATE_CONV(conv_transpose_2d, _f32, _unroll)
3618- CREATE_CONV(conv_transpose_2d, _f16_f32, _unroll)
3619+ CREATE_CONVS(_unroll)
36193620 } else {
3620- CREATE_CONV(conv2d, _f32, )
3621- CREATE_CONV(conv2d, _f16_f32, )
3622- CREATE_CONV(conv_transpose_2d, _f32, )
3623- CREATE_CONV(conv_transpose_2d, _f16_f32, )
3621+ CREATE_CONVS( )
36243622 }
36253623#undef CREATE_CONV
3624+ #undef CREATE_CONVS
36263625 }
36273626
36283627 ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
@@ -12722,6 +12721,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1272212721 // Op is disabled for Apple because it segfaults at pipeline create time on MoltenVK
1272312722 ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
1272412723 const vk_device& device = ggml_vk_get_device(ctx->device);
12724+ if (op->op == GGML_OP_CONV_TRANSPOSE_2D && !device->pipeline_conv_transpose_2d_f32[0]) {
12725+ return false;
12726+ }
1272512727 // Channel-contiguous format is not supported yet.
1272612728 return ((op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
1272712729 op->src[1]->type == GGML_TYPE_F32 &&
0 commit comments