@@ -149,6 +149,66 @@ static void ggml_vk_destroy_buffer(vk_buffer& buf);
149149
// NOTE(review): presumably the upper bound on the column count handled by the
// mul_mat_vec code paths - confirm against the users of this constant.
static constexpr uint32_t mul_mat_vec_max_cols = 8;
151151
// Coarse GPU architecture families reported by get_device_architecture().
// Kept as a plain (unscoped) enum; enumerator order and values are unchanged.
enum vk_device_architecture {
    OTHER,      // non-AMD device, or an AMD device that could not be classified
    AMD_GCN,    // AMD device whose min and max subgroup size are both 64
    AMD_RDNA1,  // AMD device with subgroup sizes 32..64 and 20 wavefronts per SIMD
    AMD_RDNA2,  // AMD device with subgroup sizes 32..64 matching neither RDNA1 nor RDNA3 checks
    AMD_RDNA3,  // AMD device with subgroup sizes 32..64 and accelerated mixed-sign packed 4x8-bit dot product
};
159+
160+ static vk_device_architecture get_device_architecture (const vk::PhysicalDevice& device) {
161+ vk::PhysicalDeviceProperties props = device.getProperties ();
162+
163+ if (props.vendorID == VK_VENDOR_ID_AMD) {
164+ const std::vector<vk::ExtensionProperties> ext_props = device.enumerateDeviceExtensionProperties ();
165+
166+ bool amd_shader_core_properties = false ;
167+ bool integer_dot_product = false ;
168+ bool subgroup_size_control = false ;
169+
170+ for (const auto & properties : ext_props) {
171+ if (strcmp (" VK_AMD_shader_core_properties" , properties.extensionName ) == 0 ) {
172+ amd_shader_core_properties = true ;
173+ } else if (strcmp (" VK_KHR_shader_integer_dot_product" , properties.extensionName ) == 0 ) {
174+ integer_dot_product = true ;
175+ } else if (strcmp (" VK_EXT_subgroup_size_control" , properties.extensionName ) == 0 ) {
176+ subgroup_size_control = true ;
177+ }
178+ }
179+
180+ if (!amd_shader_core_properties || !integer_dot_product || !subgroup_size_control) {
181+ return vk_device_architecture::OTHER;
182+ }
183+
184+ vk::PhysicalDeviceProperties2 props2;
185+ vk::PhysicalDeviceShaderCorePropertiesAMD shader_core_props_amd;
186+ vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR integer_dot_props;
187+ vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props;
188+
189+ props2.pNext = &shader_core_props_amd;
190+ shader_core_props_amd.pNext = &integer_dot_props;
191+ integer_dot_props.pNext = &subgroup_size_control_props;
192+
193+ device.getProperties2 (&props2);
194+
195+ if (subgroup_size_control_props.maxSubgroupSize == 64 && subgroup_size_control_props.minSubgroupSize == 64 ) {
196+ return vk_device_architecture::AMD_GCN;
197+ }
198+ if (subgroup_size_control_props.maxSubgroupSize == 64 && subgroup_size_control_props.minSubgroupSize == 32 ) {
199+ // RDNA
200+ if (shader_core_props_amd.wavefrontsPerSimd == 20 ) {
201+ return vk_device_architecture::AMD_RDNA1;
202+ }
203+ if (integer_dot_props.integerDotProduct4x8BitPackedMixedSignednessAccelerated ) {
204+ return vk_device_architecture::AMD_RDNA3;
205+ }
206+ return vk_device_architecture::AMD_RDNA2;
207+ }
208+ }
209+ return vk_device_architecture::OTHER;
210+ }
211+
152212struct vk_device_struct {
153213 std::mutex mutex;
154214
@@ -161,6 +221,7 @@ struct vk_device_struct {
161221 bool pipeline_robustness;
162222 vk::Device device;
163223 uint32_t vendor_id;
224+ vk_device_architecture architecture;
164225 vk_queue compute_queue;
165226 vk_queue transfer_queue;
166227 bool single_queue;
@@ -2296,7 +2357,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
22962357 device->need_compiles = false ;
22972358}
22982359
2299- static bool ggml_vk_khr_cooperative_matrix_support (const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props);
2360+ static bool ggml_vk_khr_cooperative_matrix_support (const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch );
23002361
23012362static vk_device ggml_vk_get_device (size_t idx) {
23022363 VK_LOG_DEBUG (" ggml_vk_get_device(" << idx << " )" );
@@ -2325,6 +2386,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
23252386 device->physical_device = physical_devices[dev_num];
23262387 const std::vector<vk::ExtensionProperties> ext_props = device->physical_device .enumerateDeviceExtensionProperties ();
23272388
2389+ device->architecture = get_device_architecture (device->physical_device );
2390+
23282391 const char * GGML_VK_PREFER_HOST_MEMORY = getenv (" GGML_VK_PREFER_HOST_MEMORY" );
23292392 device->prefer_host_memory = GGML_VK_PREFER_HOST_MEMORY != nullptr ;
23302393
@@ -2337,7 +2400,6 @@ static vk_device ggml_vk_get_device(size_t idx) {
23372400 bool coopmat2_support = false ;
23382401 device->coopmat_support = false ;
23392402
2340- // Check if maintenance4 is supported
23412403 for (const auto & properties : ext_props) {
23422404 if (strcmp (" VK_KHR_maintenance4" , properties.extensionName ) == 0 ) {
23432405 maintenance4_support = true ;
@@ -2450,7 +2512,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
24502512
24512513 device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
24522514
2453- if (!ggml_vk_khr_cooperative_matrix_support (device->properties , driver_props)) {
2515+ if (!ggml_vk_khr_cooperative_matrix_support (device->properties , driver_props, device-> architecture )) {
24542516 device->coopmat_support = false ;
24552517 }
24562518
@@ -2856,7 +2918,9 @@ static void ggml_vk_print_gpu_info(size_t idx) {
28562918 }
28572919 }
28582920
2859- if (!ggml_vk_khr_cooperative_matrix_support (props2.properties , driver_props)) {
2921+ const vk_device_architecture device_architecture = get_device_architecture (physical_device);
2922+
2923+ if (!ggml_vk_khr_cooperative_matrix_support (props2.properties , driver_props, device_architecture)) {
28602924 coopmat_support = false ;
28612925 }
28622926
@@ -8877,18 +8941,15 @@ static bool ggml_vk_instance_portability_enumeration_ext_available(const std::ve
88778941 UNUSED (instance_extensions);
88788942}
88798943
8880- static bool ggml_vk_khr_cooperative_matrix_support (const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props) {
8944+ static bool ggml_vk_khr_cooperative_matrix_support (const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch ) {
88818945 switch (props.vendorID ) {
88828946 case VK_VENDOR_ID_INTEL:
88838947 // Intel drivers don't support coopmat properly yet
88848948 return false ;
88858949 case VK_VENDOR_ID_AMD:
88868950 if (driver_props.driverID == vk::DriverId::eAmdProprietary || driver_props.driverID == vk::DriverId::eAmdOpenSource) {
88878951 // Workaround for AMD proprietary driver reporting support on all GPUs
8888- const std::string name = props.deviceName ;
8889- return name.rfind (" AMD Radeon RX 7" , 0 ) == 0 || name.rfind (" AMD Radeon(TM) RX 7" , 0 ) == 0 || // RDNA 3 consumer GPUs
8890- name.rfind (" AMD Radeon PRO W7" , 0 ) == 0 || name.rfind (" AMD Radeon(TM) PRO W7" , 0 ) == 0 || // RDNA 3 workstation GPUs
8891- name.rfind (" AMD Radeon 7" , 0 ) == 0 || name.rfind (" AMD Radeon(TM) 7" , 0 ) == 0 ; // RDNA 3 APUs
8952+ return arch == vk_device_architecture::AMD_RDNA3;
88928953 }
88938954 return true ;
88948955 default :
0 commit comments