2929
3030#include  " ggml-vulkan-shaders.hpp" 
3131
32- #define  VK_API_VERSION  VK_API_VERSION_1_2
33- 
3432#define  CEIL_DIV (M, N ) (((M) + (N)-1 ) / (N))
3533
3634#define  VK_VENDOR_ID_AMD  0x1002 
@@ -1611,11 +1609,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
16111609        CREATE_MM (PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT)   \
16121610        CREATE_MM (PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT)   \
16131611
1614-         CREATE_MM (pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3 )
1615-         CREATE_MM (pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3 )
1616- 
16171612        CREATE_MM2 (pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3 )
1618-         CREATE_MM2 (pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3 )
16191613        CREATE_MM (pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_0].f16acc , matmul_q4_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 )
16201614        CREATE_MM (pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_1].f16acc , matmul_q4_1_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 )
16211615        CREATE_MM (pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_0].f16acc , matmul_q5_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 )
@@ -1628,21 +1622,18 @@ static void ggml_vk_load_shaders(vk_device& device) {
16281622        CREATE_MM (pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q6_K].f16acc , matmul_q6_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3 )
16291623        CREATE_MM (pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL].f16acc , matmul_iq4_nl_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 )
16301624
1631-         CREATE_MM (pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4 )
16321625        CREATE_MM2 (pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4 )
1633-         CREATE_MM2 (pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4 )
1634- 
1635-         CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc , matmul_id_q4_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1636-         CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc , matmul_id_q4_1_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1637-         CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc , matmul_id_q5_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1638-         CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc , matmul_id_q5_1_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1639-         CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc , matmul_id_q8_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1640-         CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc , matmul_id_q2_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1641-         CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc , matmul_id_q3_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1642-         CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc , matmul_id_q4_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1643-         CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc , matmul_id_q5_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1644-         CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc , matmul_id_q6_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1645-         CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc , matmul_id_iq4_nl_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1626+         CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc , matmul_id_q4_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1627+         CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc , matmul_id_q4_1_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1628+         CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc , matmul_id_q5_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1629+         CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc , matmul_id_q5_1_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1630+         CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc , matmul_id_q8_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1631+         CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc , matmul_id_q2_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1632+         CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc , matmul_id_q3_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1633+         CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc , matmul_id_q4_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1634+         CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc , matmul_id_q5_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1635+         CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc , matmul_id_q6_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1636+         CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc , matmul_id_iq4_nl_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
16461637#undef  CREATE_MM
16471638#undef  CREATE_MM2
16481639    } else 
@@ -2284,6 +2275,14 @@ static vk_device ggml_vk_get_device(size_t idx) {
22842275        }
22852276#endif 
22862277
2278+         VkPhysicalDeviceMaintenance4Features maint4_features {};
2279+         maint4_features.sType  = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MAINTENANCE_4_FEATURES;
2280+         if  (maintenance4_support) {
2281+             last_struct->pNext  = (VkBaseOutStructure *)&maint4_features;
2282+             last_struct = (VkBaseOutStructure *)&maint4_features;
2283+             device_extensions.push_back (" VK_KHR_maintenance4"  );
2284+         }
2285+ 
22872286        vkGetPhysicalDeviceFeatures2 (device->physical_device , &device_features2);
22882287
22892288        device->fp16  = device->fp16  && vk12_features.shaderFloat16 ;
@@ -2659,7 +2658,14 @@ void ggml_vk_instance_init() {
26592658
26602659    vk_instance_initialized = true ;
26612660
2662-     vk::ApplicationInfo app_info{ " ggml-vulkan"  , 1 , nullptr , 0 , VK_API_VERSION };
2661+     uint32_t  api_version = vk::enumerateInstanceVersion ();
2662+ 
2663+     if  (api_version < VK_API_VERSION_1_2) {
2664+         std::cerr << " ggml_vulkan: Error: Vulkan 1.2 required."   << std::endl;
2665+         GGML_ABORT (" fatal error"  );
2666+     }
2667+ 
2668+     vk::ApplicationInfo app_info{ " ggml-vulkan"  , 1 , nullptr , 0 , api_version };
26632669
26642670    const  std::vector<vk::ExtensionProperties> instance_extensions = vk::enumerateInstanceExtensionProperties ();
26652671    const  bool  validation_ext = ggml_vk_instance_validation_ext_available (instance_extensions);
@@ -2969,7 +2975,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
29692975        }
29702976    }
29712977
2972-     GGML_ASSERT (src1_type == GGML_TYPE_F32);
2978+     GGML_ASSERT (src1_type == GGML_TYPE_F32 || (ctx-> device -> coopmat2  && src1_type == GGML_TYPE_F16) );
29732979
29742980    switch  (src0_type) {
29752981        case  GGML_TYPE_Q4_0:
@@ -3809,8 +3815,9 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
38093815        src1_uma = d_Qy != nullptr ;
38103816    }
38113817
3812-     const  bool  x_non_contig = !ggml_vk_dim01_contiguous (src0);
3813-     //  Reformat and convert to fp16 if src1 is non-contiguous, or for coopmat2 for better perf
3818+     //  Reformat and convert to fp16 if non-contiguous, or for coopmat2 for better perf
3819+     const  bool  x_non_contig = (ctx->device ->coopmat2  && src0->type  == GGML_TYPE_F32) ||
3820+                               !ggml_vk_dim01_contiguous (src0);
38143821    const  bool  y_non_contig = (ctx->device ->coopmat2  && src1->type  == GGML_TYPE_F32) ||
38153822                              !ggml_vk_dim01_contiguous (src1);
38163823
@@ -4390,8 +4397,11 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
43904397        ids_uma = d_ids != nullptr ;
43914398    }
43924399
4393-     const  bool  x_non_contig = !ggml_vk_dim01_contiguous (src0);
4394-     const  bool  y_non_contig = !ggml_vk_dim01_contiguous (src1);
4400+     //  Reformat and convert to fp16 if non-contiguous, or for coopmat2 for better perf
4401+     const  bool  x_non_contig = (ctx->device ->coopmat2  && src0->type  == GGML_TYPE_F32) ||                      
4402+                               !ggml_vk_dim01_contiguous (src0);
4403+     const  bool  y_non_contig = (ctx->device ->coopmat2  && src1->type  == GGML_TYPE_F32) ||
4404+                               !ggml_vk_dim01_contiguous (src1);
43954405
43964406    const  bool  y_f32_kernel = src1->type  == GGML_TYPE_F32 && !y_non_contig;
43974407
@@ -4401,7 +4411,8 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
44014411    const  bool  qy_needs_dequant = (src1->type  != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig;
44024412
44034413    if  (qx_needs_dequant) {
4404-         GGML_ABORT (" fatal error"  );
4414+         //  Fall back to dequant + f16 mulmat
4415+         mmp = ggml_vk_get_mul_mat_mat_id_pipeline (ctx, GGML_TYPE_F16, y_f32_kernel ? GGML_TYPE_F32 : GGML_TYPE_F16, (ggml_prec)dst->op_params [0 ]);
44054416    }
44064417
44074418    //  Not implemented
0 commit comments