2929
3030#include " ggml-vulkan-shaders.hpp"
3131
32- #define VK_API_VERSION VK_API_VERSION_1_2
33-
// Integer ceiling division: yields the quotient M / N rounded up.
// Intended for non-negative M and positive N; both arguments are evaluated
// more than once, so do not pass expressions with side effects.
// NOTE: there must be no whitespace between the macro name and '(' —
// otherwise the preprocessor defines an object-like macro instead.
#define CEIL_DIV(M, N) (((M) + (N) - 1) / (N))
3533
// Vendor identifier constant for AMD (0x1002).
// NOTE(review): presumably compared against the vendor ID reported by the
// Vulkan physical device — usage is not visible in this excerpt; confirm.
3634#define VK_VENDOR_ID_AMD 0x1002
@@ -1614,11 +1612,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
16141612 CREATE_MM (PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
16151613 CREATE_MM (PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
16161614
1617- CREATE_MM (pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3 )
1618- CREATE_MM (pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3 )
1619-
16201615 CREATE_MM2 (pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3 )
1621- CREATE_MM2 (pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3 )
16221616 CREATE_MM (pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_0].f16acc , matmul_q4_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 )
16231617 CREATE_MM (pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_1].f16acc , matmul_q4_1_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 )
16241618 CREATE_MM (pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_0].f16acc , matmul_q5_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 )
@@ -1631,21 +1625,18 @@ static void ggml_vk_load_shaders(vk_device& device) {
16311625 CREATE_MM (pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q6_K].f16acc , matmul_q6_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3 )
16321626 CREATE_MM (pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL].f16acc , matmul_iq4_nl_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 )
16331627
1634- CREATE_MM (pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4 )
16351628 CREATE_MM2 (pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4 )
1636- CREATE_MM2 (pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4 )
1637-
1638- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc , matmul_id_q4_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1639- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc , matmul_id_q4_1_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1640- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc , matmul_id_q5_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1641- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc , matmul_id_q5_1_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1642- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc , matmul_id_q8_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1643- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc , matmul_id_q2_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1644- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc , matmul_id_q3_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1645- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc , matmul_id_q4_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1646- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc , matmul_id_q5_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1647- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc , matmul_id_q6_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1648- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc , matmul_id_iq4_nl_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1629+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc , matmul_id_q4_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1630+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc , matmul_id_q4_1_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1631+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc , matmul_id_q5_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1632+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc , matmul_id_q5_1_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1633+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc , matmul_id_q8_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1634+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc , matmul_id_q2_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1635+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc , matmul_id_q3_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1636+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc , matmul_id_q4_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1637+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc , matmul_id_q5_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1638+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc , matmul_id_q6_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
1639+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc , matmul_id_iq4_nl_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4 )
16491640#undef CREATE_MM
16501641#undef CREATE_MM2
16511642 } else
@@ -2287,6 +2278,14 @@ static vk_device ggml_vk_get_device(size_t idx) {
22872278 }
22882279#endif
22892280
2281+ VkPhysicalDeviceMaintenance4Features maint4_features {};
2282+ maint4_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MAINTENANCE_4_FEATURES;
2283+ if (maintenance4_support) {
2284+ last_struct->pNext = (VkBaseOutStructure *)&maint4_features;
2285+ last_struct = (VkBaseOutStructure *)&maint4_features;
2286+ device_extensions.push_back (" VK_KHR_maintenance4" );
2287+ }
2288+
22902289 vkGetPhysicalDeviceFeatures2 (device->physical_device , &device_features2);
22912290
22922291 device->fp16 = device->fp16 && vk12_features.shaderFloat16 ;
@@ -2662,7 +2661,14 @@ void ggml_vk_instance_init() {
26622661
26632662 vk_instance_initialized = true ;
26642663
2665- vk::ApplicationInfo app_info{ " ggml-vulkan" , 1 , nullptr , 0 , VK_API_VERSION };
2664+ uint32_t api_version = vk::enumerateInstanceVersion ();
2665+
2666+ if (api_version < VK_API_VERSION_1_2) {
2667+ std::cerr << " ggml_vulkan: Error: Vulkan 1.2 required." << std::endl;
2668+ GGML_ABORT (" fatal error" );
2669+ }
2670+
2671+ vk::ApplicationInfo app_info{ " ggml-vulkan" , 1 , nullptr , 0 , api_version };
26662672
26672673 const std::vector<vk::ExtensionProperties> instance_extensions = vk::enumerateInstanceExtensionProperties ();
26682674 const bool validation_ext = ggml_vk_instance_validation_ext_available (instance_extensions);
@@ -2972,7 +2978,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
29722978 }
29732979 }
29742980
2975- GGML_ASSERT (src1_type == GGML_TYPE_F32);
2981+ GGML_ASSERT (src1_type == GGML_TYPE_F32 || (ctx-> device -> coopmat2 && src1_type == GGML_TYPE_F16) );
29762982
29772983 switch (src0_type) {
29782984 case GGML_TYPE_Q4_0:
@@ -3812,8 +3818,9 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
38123818 src1_uma = d_Qy != nullptr ;
38133819 }
38143820
3815- const bool x_non_contig = !ggml_vk_dim01_contiguous (src0);
3816- // Reformat and convert to fp16 if src1 is non-contiguous, or for coopmat2 for better perf
3821+ // Reformat and convert to fp16 if non-contiguous, or for coopmat2 for better perf
3822+ const bool x_non_contig = (ctx->device ->coopmat2 && src0->type == GGML_TYPE_F32) ||
3823+ !ggml_vk_dim01_contiguous (src0);
38173824 const bool y_non_contig = (ctx->device ->coopmat2 && src1->type == GGML_TYPE_F32) ||
38183825 !ggml_vk_dim01_contiguous (src1);
38193826
@@ -4393,8 +4400,11 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
43934400 ids_uma = d_ids != nullptr ;
43944401 }
43954402
4396- const bool x_non_contig = !ggml_vk_dim01_contiguous (src0);
4397- const bool y_non_contig = !ggml_vk_dim01_contiguous (src1);
4403+ // Reformat and convert to fp16 if non-contiguous, or for coopmat2 for better perf
4404+ const bool x_non_contig = (ctx->device ->coopmat2 && src0->type == GGML_TYPE_F32) ||
4405+ !ggml_vk_dim01_contiguous (src0);
4406+ const bool y_non_contig = (ctx->device ->coopmat2 && src1->type == GGML_TYPE_F32) ||
4407+ !ggml_vk_dim01_contiguous (src1);
43984408
43994409 const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
44004410
@@ -4404,7 +4414,8 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
44044414 const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig;
44054415
44064416 if (qx_needs_dequant) {
4407- GGML_ABORT (" fatal error" );
4417+ // Fall back to dequant + f16 mulmat
4418+ mmp = ggml_vk_get_mul_mat_mat_id_pipeline (ctx, GGML_TYPE_F16, y_f32_kernel ? GGML_TYPE_F32 : GGML_TYPE_F16, (ggml_prec)dst->op_params [0 ]);
44084419 }
44094420
44104421 // Not implemented
0 commit comments