@@ -2040,6 +2040,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
20402040 std::cerr << " Done!" << std::endl;
20412041}
20422042
2043+ static bool ggml_vk_khr_cooperative_matrix_support (const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props);
2044+
20432045static vk_device ggml_vk_get_device (size_t idx) {
20442046 VK_LOG_DEBUG (" ggml_vk_get_device(" << idx << " )" );
20452047
@@ -2175,9 +2177,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
21752177
21762178 device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
21772179
2178- if (device->vendor_id == VK_VENDOR_ID_INTEL || (device->vendor_id == VK_VENDOR_ID_AMD && (driver_props.driverID == vk::DriverId::eAmdProprietary || driver_props.driverID == vk::DriverId::eAmdOpenSource))) {
2179- // Intel drivers don't support coopmat properly yet
2180- // Only RADV supports coopmat properly on AMD
2180+ if (!ggml_vk_khr_cooperative_matrix_support (device->properties , driver_props)) {
21812181 device->coopmat_support = false ;
21822182 }
21832183
@@ -2515,7 +2515,6 @@ static vk_device ggml_vk_get_device(size_t idx) {
25152515 return vk_instance.devices [idx];
25162516}
25172517
2518-
25192518static void ggml_vk_print_gpu_info (size_t idx) {
25202519 GGML_ASSERT (idx < vk_instance.device_indices .size ());
25212520 size_t dev_num = vk_instance.device_indices [idx];
@@ -2565,9 +2564,7 @@ static void ggml_vk_print_gpu_info(size_t idx) {
25652564 }
25662565 }
25672566
2568- if (props2.properties .vendorID == VK_VENDOR_ID_INTEL || (props2.properties .vendorID == VK_VENDOR_ID_AMD && (driver_props.driverID == vk::DriverId::eAmdProprietary || driver_props.driverID == vk::DriverId::eAmdOpenSource))) {
2569- // Intel drivers don't support coopmat properly yet
2570- // Only RADV supports coopmat properly on AMD
2567+ if (!ggml_vk_khr_cooperative_matrix_support (props2.properties , driver_props)) {
25712568 coopmat_support = false ;
25722569 }
25732570
@@ -8088,6 +8085,25 @@ static bool ggml_vk_instance_portability_enumeration_ext_available(const std::ve
80888085 UNUSED (instance_extensions);
80898086}
80908087
8088+ static bool ggml_vk_khr_cooperative_matrix_support (const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props) {
8089+ switch (props.vendorID ) {
8090+ case VK_VENDOR_ID_INTEL:
8091+ // Intel drivers don't support coopmat properly yet
8092+ return false ;
8093+ case VK_VENDOR_ID_AMD:
8094+ if (driver_props.driverID == vk::DriverId::eAmdProprietary || driver_props.driverID == vk::DriverId::eAmdOpenSource) {
8095+ // Workaround for AMD proprietary driver reporting support on all GPUs
8096+ const std::string name = props.deviceName ;
8097+ return name.rfind (" AMD Radeon RX 7" , 0 ) == 0 || // RDNA 3 consumer GPUs
8098+ name.rfind (" AMD Radeon PRO W7" , 0 ) == 0 || // RDNA 3 workstation GPUs
8099+ name.rfind (" AMD Radeon 7" , 0 ) == 0 ; // RDNA 3 APUs
8100+ }
8101+ return true ;
8102+ default :
8103+ return true ;
8104+ }
8105+ }
8106+
80918107// checks
80928108
80938109#ifdef GGML_VULKAN_CHECK_RESULTS
0 commit comments