2929
3030#include " ggml-vulkan-shaders.hpp"
3131
32- #define VK_API_VERSION VK_API_VERSION_1_2
33-
3432#define CEIL_DIV (M, N ) (((M) + (N)-1 ) / (N))
3533
3634#define VK_VENDOR_ID_AMD 0x1002
@@ -1611,11 +1609,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
16111609 CREATE_MM (PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
16121610 CREATE_MM (PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
16131611
1614- CREATE_MM (pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3 )
1615- CREATE_MM (pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3 )
1616-
16171612 CREATE_MM2 (pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3 )
1618- CREATE_MM2 (pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3 )
16191613 CREATE_MM (pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_0].f16acc , matmul_q4_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 )
16201614 CREATE_MM (pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_1].f16acc , matmul_q4_1_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 )
16211615 CREATE_MM (pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_0].f16acc , matmul_q5_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 )
@@ -1628,21 +1622,18 @@ static void ggml_vk_load_shaders(vk_device& device) {
16281622 CREATE_MM (pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q6_K].f16acc , matmul_q6_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3 )
16291623 CREATE_MM (pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL].f16acc , matmul_iq4_nl_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 )
16301624
1631- CREATE_MM (pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4 )
16321625 CREATE_MM2 (pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4 )
1633- CREATE_MM2 (pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4 )
1634-
1635- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc , matmul_id_q4_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1636- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc , matmul_id_q4_1_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1637- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc , matmul_id_q5_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1638- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc , matmul_id_q5_1_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1639- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc , matmul_id_q8_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1640- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc , matmul_id_q2_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1641- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc , matmul_id_q3_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1642- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc , matmul_id_q4_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1643- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc , matmul_id_q5_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1644- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc , matmul_id_q6_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1645- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc , matmul_id_iq4_nl_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1626+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc , matmul_id_q4_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1627+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc , matmul_id_q4_1_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1628+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc , matmul_id_q5_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1629+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc , matmul_id_q5_1_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1630+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc , matmul_id_q8_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1631+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc , matmul_id_q2_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1632+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc , matmul_id_q3_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1633+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc , matmul_id_q4_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1634+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc , matmul_id_q5_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1635+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc , matmul_id_q6_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1636+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc , matmul_id_iq4_nl_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
16461637#undef CREATE_MM
16471638#undef CREATE_MM2
16481639 } else
@@ -2284,6 +2275,14 @@ static vk_device ggml_vk_get_device(size_t idx) {
22842275 }
22852276#endif
22862277
2278+ VkPhysicalDeviceMaintenance4Features maint4_features {};
2279+ maint4_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MAINTENANCE_4_FEATURES;
2280+ if (maintenance4_support) {
2281+ last_struct->pNext = (VkBaseOutStructure *)&maint4_features;
2282+ last_struct = (VkBaseOutStructure *)&maint4_features;
2283+ device_extensions.push_back (" VK_KHR_maintenance4" );
2284+ }
2285+
22872286 vkGetPhysicalDeviceFeatures2 (device->physical_device , &device_features2);
22882287
22892288 device->fp16 = device->fp16 && vk12_features.shaderFloat16 ;
@@ -2659,7 +2658,14 @@ void ggml_vk_instance_init() {
26592658
26602659 vk_instance_initialized = true ;
26612660
2662- vk::ApplicationInfo app_info{ " ggml-vulkan" , 1 , nullptr , 0 , VK_API_VERSION };
2661+ uint32_t api_version = vk::enumerateInstanceVersion ();
2662+
2663+ if (api_version < VK_API_VERSION_1_2) {
2664+ std::cerr << " ggml_vulkan: Error: Vulkan 1.2 required." << std::endl;
2665+ GGML_ABORT (" fatal error" );
2666+ }
2667+
2668+ vk::ApplicationInfo app_info{ " ggml-vulkan" , 1 , nullptr , 0 , api_version };
26632669
26642670 const std::vector<vk::ExtensionProperties> instance_extensions = vk::enumerateInstanceExtensionProperties ();
26652671 const bool validation_ext = ggml_vk_instance_validation_ext_available (instance_extensions);
@@ -2969,7 +2975,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
29692975 }
29702976 }
29712977
2972- GGML_ASSERT (src1_type == GGML_TYPE_F32);
2978+ GGML_ASSERT (src1_type == GGML_TYPE_F32 || (ctx-> device -> coopmat2 && src1_type == GGML_TYPE_F16) );
29732979
29742980 switch (src0_type) {
29752981 case GGML_TYPE_Q4_0:
@@ -3809,8 +3815,9 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
38093815 src1_uma = d_Qy != nullptr ;
38103816 }
38113817
3812- const bool x_non_contig = !ggml_vk_dim01_contiguous (src0);
3813- // Reformat and convert to fp16 if src1 is non-contiguous, or for coopmat2 for better perf
3818+ // Reformat and convert to fp16 if non-contiguous, or for coopmat2 for better perf
3819+ const bool x_non_contig = (ctx->device ->coopmat2 && src0->type == GGML_TYPE_F32) ||
3820+ !ggml_vk_dim01_contiguous (src0);
38143821 const bool y_non_contig = (ctx->device ->coopmat2 && src1->type == GGML_TYPE_F32) ||
38153822 !ggml_vk_dim01_contiguous (src1);
38163823
@@ -4390,8 +4397,11 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
43904397 ids_uma = d_ids != nullptr ;
43914398 }
43924399
4393- const bool x_non_contig = !ggml_vk_dim01_contiguous (src0);
4394- const bool y_non_contig = !ggml_vk_dim01_contiguous (src1);
4400+ // Reformat and convert to fp16 if non-contiguous, or for coopmat2 for better perf
4401+ const bool x_non_contig = (ctx->device ->coopmat2 && src0->type == GGML_TYPE_F32) ||
4402+ !ggml_vk_dim01_contiguous (src0);
4403+ const bool y_non_contig = (ctx->device ->coopmat2 && src1->type == GGML_TYPE_F32) ||
4404+ !ggml_vk_dim01_contiguous (src1);
43954405
43964406 const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
43974407
@@ -4401,7 +4411,8 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
44014411 const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig;
44024412
44034413 if (qx_needs_dequant) {
4404- GGML_ABORT (" fatal error" );
4414+ // Fall back to dequant + f16 mulmat
4415+ mmp = ggml_vk_get_mul_mat_mat_id_pipeline (ctx, GGML_TYPE_F16, y_f32_kernel ? GGML_TYPE_F32 : GGML_TYPE_F16, (ggml_prec)dst->op_params [0 ]);
44054416 }
44064417
44074418 // Not implemented
0 commit comments