@@ -149,6 +149,66 @@ static void ggml_vk_destroy_buffer(vk_buffer& buf);
149149
150150static constexpr uint32_t mul_mat_vec_max_cols = 8 ;
151151
152+ enum vk_device_architecture {
153+ OTHER,
154+ AMD_GCN,
155+ AMD_RDNA1,
156+ AMD_RDNA2,
157+ AMD_RDNA3,
158+ };
159+
160+ static vk_device_architecture get_device_architecture (const vk::PhysicalDevice& device) {
161+ vk::PhysicalDeviceProperties props = device.getProperties ();
162+
163+ if (props.vendorID == VK_VENDOR_ID_AMD) {
164+ const std::vector<vk::ExtensionProperties> ext_props = device.enumerateDeviceExtensionProperties ();
165+
166+ bool amd_shader_core_properties = false ;
167+ bool integer_dot_product = false ;
168+ bool subgroup_size_control = false ;
169+
170+ for (const auto & properties : ext_props) {
171+ if (strcmp (" VK_AMD_shader_core_properties" , properties.extensionName ) == 0 ) {
172+ amd_shader_core_properties = true ;
173+ } else if (strcmp (" VK_KHR_shader_integer_dot_product" , properties.extensionName ) == 0 ) {
174+ integer_dot_product = true ;
175+ } else if (strcmp (" VK_EXT_subgroup_size_control" , properties.extensionName ) == 0 ) {
176+ subgroup_size_control = true ;
177+ }
178+ }
179+
180+ if (!amd_shader_core_properties || !integer_dot_product || !subgroup_size_control) {
181+ return vk_device_architecture::OTHER;
182+ }
183+
184+ vk::PhysicalDeviceProperties2 props2;
185+ vk::PhysicalDeviceShaderCorePropertiesAMD shader_core_props_amd;
186+ vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR integer_dot_props;
187+ vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props;
188+
189+ props2.pNext = &shader_core_props_amd;
190+ shader_core_props_amd.pNext = &integer_dot_props;
191+ integer_dot_props.pNext = &subgroup_size_control_props;
192+
193+ device.getProperties2 (&props2);
194+
195+ if (subgroup_size_control_props.maxSubgroupSize == 64 && subgroup_size_control_props.minSubgroupSize == 64 ) {
196+ return vk_device_architecture::AMD_GCN;
197+ }
198+ if (subgroup_size_control_props.maxSubgroupSize == 64 && subgroup_size_control_props.minSubgroupSize == 32 ) {
199+ // RDNA
200+ if (shader_core_props_amd.wavefrontsPerSimd == 20 ) {
201+ return vk_device_architecture::AMD_RDNA1;
202+ }
203+ if (integer_dot_props.integerDotProduct4x8BitPackedMixedSignednessAccelerated ) {
204+ return vk_device_architecture::AMD_RDNA3;
205+ }
206+ return vk_device_architecture::AMD_RDNA2;
207+ }
208+ }
209+ return vk_device_architecture::OTHER;
210+ }
211+
152212struct vk_device_struct {
153213 std::mutex mutex;
154214
@@ -161,6 +221,7 @@ struct vk_device_struct {
161221 bool pipeline_robustness;
162222 vk::Device device;
163223 uint32_t vendor_id;
224+ vk_device_architecture architecture;
164225 vk_queue compute_queue;
165226 vk_queue transfer_queue;
166227 bool single_queue;
@@ -2219,7 +2280,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
22192280 device->need_compiles = false ;
22202281}
22212282
2222- static bool ggml_vk_khr_cooperative_matrix_support (const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props);
2283+ static bool ggml_vk_khr_cooperative_matrix_support (const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch );
22232284
22242285static vk_device ggml_vk_get_device (size_t idx) {
22252286 VK_LOG_DEBUG (" ggml_vk_get_device(" << idx << " )" );
@@ -2248,6 +2309,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
22482309 device->physical_device = physical_devices[dev_num];
22492310 const std::vector<vk::ExtensionProperties> ext_props = device->physical_device .enumerateDeviceExtensionProperties ();
22502311
2312+ device->architecture = get_device_architecture (device->physical_device );
2313+
22512314 bool fp16_storage = false ;
22522315 bool fp16_compute = false ;
22532316 bool maintenance4_support = false ;
@@ -2257,7 +2320,6 @@ static vk_device ggml_vk_get_device(size_t idx) {
22572320 bool coopmat2_support = false ;
22582321 device->coopmat_support = false ;
22592322
2260- // Check if maintenance4 is supported
22612323 for (const auto & properties : ext_props) {
22622324 if (strcmp (" VK_KHR_maintenance4" , properties.extensionName ) == 0 ) {
22632325 maintenance4_support = true ;
@@ -2370,7 +2432,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
23702432
23712433 device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
23722434
2373- if (!ggml_vk_khr_cooperative_matrix_support (device->properties , driver_props)) {
2435+ if (!ggml_vk_khr_cooperative_matrix_support (device->properties , driver_props, device-> architecture )) {
23742436 device->coopmat_support = false ;
23752437 }
23762438
@@ -2776,7 +2838,9 @@ static void ggml_vk_print_gpu_info(size_t idx) {
27762838 }
27772839 }
27782840
2779- if (!ggml_vk_khr_cooperative_matrix_support (props2.properties , driver_props)) {
2841+ const vk_device_architecture device_architecture = get_device_architecture (physical_device);
2842+
2843+ if (!ggml_vk_khr_cooperative_matrix_support (props2.properties , driver_props, device_architecture)) {
27802844 coopmat_support = false ;
27812845 }
27822846
@@ -8435,18 +8499,15 @@ static bool ggml_vk_instance_portability_enumeration_ext_available(const std::ve
84358499 UNUSED (instance_extensions);
84368500}
84378501
8438- static bool ggml_vk_khr_cooperative_matrix_support (const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props) {
8502+ static bool ggml_vk_khr_cooperative_matrix_support (const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch ) {
84398503 switch (props.vendorID ) {
84408504 case VK_VENDOR_ID_INTEL:
84418505 // Intel drivers don't support coopmat properly yet
84428506 return false ;
84438507 case VK_VENDOR_ID_AMD:
84448508 if (driver_props.driverID == vk::DriverId::eAmdProprietary || driver_props.driverID == vk::DriverId::eAmdOpenSource) {
84458509 // Workaround for AMD proprietary driver reporting support on all GPUs
8446- const std::string name = props.deviceName ;
8447- return name.rfind (" AMD Radeon RX 7" , 0 ) == 0 || name.rfind (" AMD Radeon(TM) RX 7" , 0 ) == 0 || // RDNA 3 consumer GPUs
8448- name.rfind (" AMD Radeon PRO W7" , 0 ) == 0 || name.rfind (" AMD Radeon(TM) PRO W7" , 0 ) == 0 || // RDNA 3 workstation GPUs
8449- name.rfind (" AMD Radeon 7" , 0 ) == 0 || name.rfind (" AMD Radeon(TM) 7" , 0 ) == 0 ; // RDNA 3 APUs
8510+ return arch == vk_device_architecture::AMD_RDNA3;
84508511 }
84518512 return true ;
84528513 default :
0 commit comments