@@ -2040,6 +2040,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
20402040    std::cerr << " Done!" 
20412041}
20422042
2043+ static  bool  ggml_vk_khr_cooperative_matrix_support (const  vk::PhysicalDeviceProperties& props, const  vk::PhysicalDeviceDriverProperties& driver_props);
2044+ 
20432045static  vk_device ggml_vk_get_device (size_t  idx) {
20442046    VK_LOG_DEBUG (" ggml_vk_get_device(" " )" 
20452047
@@ -2175,9 +2177,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
21752177
21762178        device->fp16  = !force_disable_f16 && fp16_storage && fp16_compute;
21772179
2178-         if  (device->vendor_id  == VK_VENDOR_ID_INTEL || (device->vendor_id  == VK_VENDOR_ID_AMD && (driver_props.driverID  == vk::DriverId::eAmdProprietary || driver_props.driverID  == vk::DriverId::eAmdOpenSource))) {
2179-             //  Intel drivers don't support coopmat properly yet
2180-             //  Only RADV supports coopmat properly on AMD
2180+         if  (!ggml_vk_khr_cooperative_matrix_support (device->properties , driver_props)) {
21812181            device->coopmat_support  = false ;
21822182        }
21832183
@@ -2515,7 +2515,6 @@ static vk_device ggml_vk_get_device(size_t idx) {
25152515    return  vk_instance.devices [idx];
25162516}
25172517
2518- 
25192518static  void  ggml_vk_print_gpu_info (size_t  idx) {
25202519    GGML_ASSERT (idx < vk_instance.device_indices .size ());
25212520    size_t  dev_num = vk_instance.device_indices [idx];
@@ -2565,9 +2564,7 @@ static void ggml_vk_print_gpu_info(size_t idx) {
25652564        }
25662565    }
25672566
2568-     if  (props2.properties .vendorID  == VK_VENDOR_ID_INTEL || (props2.properties .vendorID  == VK_VENDOR_ID_AMD && (driver_props.driverID  == vk::DriverId::eAmdProprietary || driver_props.driverID  == vk::DriverId::eAmdOpenSource))) {
2569-         //  Intel drivers don't support coopmat properly yet
2570-         //  Only RADV supports coopmat properly on AMD
2567+     if  (!ggml_vk_khr_cooperative_matrix_support (props2.properties , driver_props)) {
25712568        coopmat_support = false ;
25722569    }
25732570
@@ -8088,6 +8085,25 @@ static bool ggml_vk_instance_portability_enumeration_ext_available(const std::ve
80888085    UNUSED (instance_extensions);
80898086}
80908087
8088+ static  bool  ggml_vk_khr_cooperative_matrix_support (const  vk::PhysicalDeviceProperties& props, const  vk::PhysicalDeviceDriverProperties& driver_props) {
8089+     switch  (props.vendorID ) {
8090+     case  VK_VENDOR_ID_INTEL:
8091+         //  Intel drivers don't support coopmat properly yet
8092+         return  false ;
8093+     case  VK_VENDOR_ID_AMD:
8094+         if  (driver_props.driverID  == vk::DriverId::eAmdProprietary || driver_props.driverID  == vk::DriverId::eAmdOpenSource) {
8095+             //  Workaround for AMD proprietary driver reporting support on all GPUs
8096+             const  std::string name = props.deviceName ;
8097+             return  name.rfind (" AMD Radeon RX 7" 0 ) == 0    || name.rfind (" AMD Radeon(TM) RX 7" 0 ) == 0    || //  RDNA 3 consumer GPUs
8098+                    name.rfind (" AMD Radeon PRO W7" 0 ) == 0  || name.rfind (" AMD Radeon(TM) PRO W7" 0 ) == 0  || //  RDNA 3 workstation GPUs
8099+                    name.rfind (" AMD Radeon 7" 0 ) == 0       || name.rfind (" AMD Radeon(TM) 7" 0 ) == 0 ;        //  RDNA 3 APUs
8100+         }
8101+         return  true ;
8102+     default :
8103+         return  true ;
8104+     }
8105+ }
8106+ 
80918107//  checks
80928108
80938109#ifdef  GGML_VULKAN_CHECK_RESULTS
0 commit comments