2929
3030#include  " ggml-vulkan-shaders.hpp" 
3131
32- #define  VK_API_VERSION  VK_API_VERSION_1_2
33- 
// CEIL_DIV(M, N): integer division of M by N rounded up, computed by adding
// (N - 1) to M before dividing.
// N appears twice in the expansion, so an argument with side effects would be
// evaluated more than once.
// NOTE(review): the round-up result presumably assumes non-negative M and
// positive N — confirm against callers.
3432#define  CEIL_DIV (M, N ) (((M) + (N)-1 ) / (N))
3533
// Vendor identifier constant for AMD devices (0x1002).
// NOTE(review): presumably compared against the vendorID reported in the
// Vulkan physical-device properties — confirm at the use sites.
3634#define  VK_VENDOR_ID_AMD  0x1002 
@@ -1614,11 +1612,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
16141612        CREATE_MM (PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT)   \
16151613        CREATE_MM (PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT)   \
16161614
1617-         CREATE_MM (pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3 )
1618-         CREATE_MM (pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3 )
1619- 
16201615        CREATE_MM2 (pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3 )
1621-         CREATE_MM2 (pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3 )
16221616        CREATE_MM (pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_0].f16acc , matmul_q4_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 )
16231617        CREATE_MM (pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_1].f16acc , matmul_q4_1_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 )
16241618        CREATE_MM (pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_0].f16acc , matmul_q5_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 )
@@ -1631,21 +1625,18 @@ static void ggml_vk_load_shaders(vk_device& device) {
16311625        CREATE_MM (pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q6_K].f16acc , matmul_q6_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3 )
16321626        CREATE_MM (pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL].f16acc , matmul_iq4_nl_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 )
16331627
1634-         CREATE_MM (pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4 )
16351628        CREATE_MM2 (pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4 )
1636-         CREATE_MM2 (pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4 )
1637- 
1638-         CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc , matmul_id_q4_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1639-         CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc , matmul_id_q4_1_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1640-         CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc , matmul_id_q5_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1641-         CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc , matmul_id_q5_1_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1642-         CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc , matmul_id_q8_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1643-         CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc , matmul_id_q2_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1644-         CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc , matmul_id_q3_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1645-         CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc , matmul_id_q4_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1646-         CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc , matmul_id_q5_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1647-         CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc , matmul_id_q6_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1648-         CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc , matmul_id_iq4_nl_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1629+         CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc , matmul_id_q4_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1630+         CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc , matmul_id_q4_1_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1631+         CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc , matmul_id_q5_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1632+         CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc , matmul_id_q5_1_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1633+         CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc , matmul_id_q8_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1634+         CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc , matmul_id_q2_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1635+         CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc , matmul_id_q3_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1636+         CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc , matmul_id_q4_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1637+         CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc , matmul_id_q5_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1638+         CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc , matmul_id_q6_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1639+         CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc , matmul_id_iq4_nl_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
16491640#undef  CREATE_MM
16501641#undef  CREATE_MM2
16511642    } else 
@@ -2287,6 +2278,14 @@ static vk_device ggml_vk_get_device(size_t idx) {
22872278        }
22882279#endif 
22892280
2281+         VkPhysicalDeviceMaintenance4Features maint4_features {};
2282+         maint4_features.sType  = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MAINTENANCE_4_FEATURES;
2283+         if  (maintenance4_support) {
2284+             last_struct->pNext  = (VkBaseOutStructure *)&maint4_features;
2285+             last_struct = (VkBaseOutStructure *)&maint4_features;
2286+             device_extensions.push_back (" VK_KHR_maintenance4"  );
2287+         }
2288+ 
22902289        vkGetPhysicalDeviceFeatures2 (device->physical_device , &device_features2);
22912290
22922291        device->fp16  = device->fp16  && vk12_features.shaderFloat16 ;
@@ -2662,7 +2661,14 @@ void ggml_vk_instance_init() {
26622661
26632662    vk_instance_initialized = true ;
26642663
2665-     vk::ApplicationInfo app_info{ " ggml-vulkan"  , 1 , nullptr , 0 , VK_API_VERSION };
2664+     uint32_t  api_version = vk::enumerateInstanceVersion ();
2665+ 
2666+     if  (api_version < VK_API_VERSION_1_2) {
2667+         std::cerr << " ggml_vulkan: Error: Vulkan 1.2 required."   << std::endl;
2668+         GGML_ABORT (" fatal error"  );
2669+     }
2670+ 
2671+     vk::ApplicationInfo app_info{ " ggml-vulkan"  , 1 , nullptr , 0 , api_version };
26662672
26672673    const  std::vector<vk::ExtensionProperties> instance_extensions = vk::enumerateInstanceExtensionProperties ();
26682674    const  bool  validation_ext = ggml_vk_instance_validation_ext_available (instance_extensions);
@@ -2972,7 +2978,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
29722978        }
29732979    }
29742980
2975-     GGML_ASSERT (src1_type == GGML_TYPE_F32);
2981+     GGML_ASSERT (src1_type == GGML_TYPE_F32 || (ctx-> device -> coopmat2  && src1_type == GGML_TYPE_F16) );
29762982
29772983    switch  (src0_type) {
29782984        case  GGML_TYPE_Q4_0:
@@ -3812,8 +3818,9 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
38123818        src1_uma = d_Qy != nullptr ;
38133819    }
38143820
3815-     const  bool  x_non_contig = !ggml_vk_dim01_contiguous (src0);
3816-     //  Reformat and convert to fp16 if src1 is non-contiguous, or for coopmat2 for better perf
3821+     //  Reformat and convert to fp16 if non-contiguous, or for coopmat2 for better perf
3822+     const  bool  x_non_contig = (ctx->device ->coopmat2  && src0->type  == GGML_TYPE_F32) ||
3823+                               !ggml_vk_dim01_contiguous (src0);
38173824    const  bool  y_non_contig = (ctx->device ->coopmat2  && src1->type  == GGML_TYPE_F32) ||
38183825                              !ggml_vk_dim01_contiguous (src1);
38193826
@@ -4393,8 +4400,11 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
43934400        ids_uma = d_ids != nullptr ;
43944401    }
43954402
4396-     const  bool  x_non_contig = !ggml_vk_dim01_contiguous (src0);
4397-     const  bool  y_non_contig = !ggml_vk_dim01_contiguous (src1);
4403+     //  Reformat and convert to fp16 if non-contiguous, or for coopmat2 for better perf
4404+     const  bool  x_non_contig = (ctx->device ->coopmat2  && src0->type  == GGML_TYPE_F32) ||
4405+                               !ggml_vk_dim01_contiguous (src0);
4406+     const  bool  y_non_contig = (ctx->device ->coopmat2  && src1->type  == GGML_TYPE_F32) ||
4407+                               !ggml_vk_dim01_contiguous (src1);
43984408
43994409    const  bool  y_f32_kernel = src1->type  == GGML_TYPE_F32 && !y_non_contig;
44004410
@@ -4404,7 +4414,8 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
44044414    const  bool  qy_needs_dequant = (src1->type  != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig;
44054415
44064416    if  (qx_needs_dequant) {
4407-         GGML_ABORT (" fatal error"  );
4417+         //  Fall back to dequant + f16 mulmat
4418+         mmp = ggml_vk_get_mul_mat_mat_id_pipeline (ctx, GGML_TYPE_F16, y_f32_kernel ? GGML_TYPE_F32 : GGML_TYPE_F16, (ggml_prec)dst->op_params [0 ]);
44084419    }
44094420
44104421    //  Not implemented
0 commit comments