|
10 | 10 | #include "common.hpp" |
11 | 11 | #include "platform.hpp" |
12 | 12 |
|
| 13 | +#include <algorithm> |
13 | 14 | #include <array> |
14 | 15 | #include <cassert> |
15 | 16 |
|
@@ -938,17 +939,25 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo(ur_device_handle_t hDevice, |
938 | 939 | return UR_RESULT_SUCCESS; |
939 | 940 | } |
940 | 941 | case UR_DEVICE_INFO_SUB_GROUP_SIZES_INTEL: { |
941 | | - // Have to convert size_t to uint32_t |
942 | | - size_t SubGroupSizesSize = 0; |
943 | | - CL_RETURN_ON_FAILURE( |
944 | | - clGetDeviceInfo(cl_adapter::cast<cl_device_id>(hDevice), CLPropName, 0, |
945 | | - nullptr, &SubGroupSizesSize)); |
946 | | - std::vector<size_t> SubGroupSizes(SubGroupSizesSize / sizeof(size_t)); |
947 | | - CL_RETURN_ON_FAILURE( |
948 | | - clGetDeviceInfo(cl_adapter::cast<cl_device_id>(hDevice), CLPropName, |
949 | | - SubGroupSizesSize, SubGroupSizes.data(), nullptr)); |
950 | | - return ReturnValue.template operator()<uint32_t>(SubGroupSizes.data(), |
951 | | - SubGroupSizes.size()); |
| 942 | + size_t ExtSize = 0; |
| 943 | + urDeviceGetInfo(hDevice, UR_DEVICE_INFO_EXTENSIONS, 0, nullptr, &ExtSize); |
| 944 | + std::string ExtStr(ExtSize, 0); |
| 945 | + urDeviceGetInfo(hDevice, UR_DEVICE_INFO_EXTENSIONS, ExtSize, ExtStr.data(), nullptr); |
| 946 | + if (ExtStr.find("cl_intel_required_subgroup_size")!=std::string::npos) { |
| 947 | + // Have to convert size_t to uint32_t |
| 948 | + size_t SubGroupSizesSize = 0; |
| 949 | + CL_RETURN_ON_FAILURE( |
| 950 | + clGetDeviceInfo(cl_adapter::cast<cl_device_id>(hDevice), CLPropName, 0, |
| 951 | + nullptr, &SubGroupSizesSize)); |
| 952 | + std::vector<size_t> SubGroupSizes(SubGroupSizesSize / sizeof(size_t)); |
| 953 | + CL_RETURN_ON_FAILURE( |
| 954 | + clGetDeviceInfo(cl_adapter::cast<cl_device_id>(hDevice), CLPropName, |
| 955 | + SubGroupSizesSize, SubGroupSizes.data(), nullptr)); |
| 956 | + return ReturnValue.template operator()<uint32_t>(SubGroupSizes.data(), |
| 957 | + SubGroupSizes.size()); |
| 958 | + } else { |
| 959 | + return ReturnValue.template operator()<uint32_t>(std::data({1}),1); |
| 960 | + } |
952 | 961 | } |
953 | 962 | case UR_DEVICE_INFO_EXTENSIONS: { |
954 | 963 | cl_device_id Dev = cl_adapter::cast<cl_device_id>(hDevice); |
|
0 commit comments