|
10 | 10 | #include "common.hpp"
|
11 | 11 | #include "platform.hpp"
|
12 | 12 |
|
| 13 | +#include <algorithm> |
13 | 14 | #include <array>
|
14 | 15 | #include <cassert>
|
15 | 16 |
|
@@ -938,17 +939,25 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo(ur_device_handle_t hDevice,
|
938 | 939 | return UR_RESULT_SUCCESS;
|
939 | 940 | }
|
940 | 941 | case UR_DEVICE_INFO_SUB_GROUP_SIZES_INTEL: {
|
941 |
| - // Have to convert size_t to uint32_t |
942 |
| - size_t SubGroupSizesSize = 0; |
943 |
| - CL_RETURN_ON_FAILURE( |
944 |
| - clGetDeviceInfo(cl_adapter::cast<cl_device_id>(hDevice), CLPropName, 0, |
945 |
| - nullptr, &SubGroupSizesSize)); |
946 |
| - std::vector<size_t> SubGroupSizes(SubGroupSizesSize / sizeof(size_t)); |
947 |
| - CL_RETURN_ON_FAILURE( |
948 |
| - clGetDeviceInfo(cl_adapter::cast<cl_device_id>(hDevice), CLPropName, |
949 |
| - SubGroupSizesSize, SubGroupSizes.data(), nullptr)); |
950 |
| - return ReturnValue.template operator()<uint32_t>(SubGroupSizes.data(), |
951 |
| - SubGroupSizes.size()); |
| 942 | + size_t ExtSize = 0; |
| 943 | + urDeviceGetInfo(hDevice, UR_DEVICE_INFO_EXTENSIONS, 0, nullptr, &ExtSize); |
| 944 | + std::string ExtStr(ExtSize, 0); |
| 945 | + urDeviceGetInfo(hDevice, UR_DEVICE_INFO_EXTENSIONS, ExtSize, ExtStr.data(), nullptr); |
| 946 | + if (ExtStr.find("cl_intel_required_subgroup_size")!=std::string::npos) { |
| 947 | + // Have to convert size_t to uint32_t |
| 948 | + size_t SubGroupSizesSize = 0; |
| 949 | + CL_RETURN_ON_FAILURE( |
| 950 | + clGetDeviceInfo(cl_adapter::cast<cl_device_id>(hDevice), CLPropName, 0, |
| 951 | + nullptr, &SubGroupSizesSize)); |
| 952 | + std::vector<size_t> SubGroupSizes(SubGroupSizesSize / sizeof(size_t)); |
| 953 | + CL_RETURN_ON_FAILURE( |
| 954 | + clGetDeviceInfo(cl_adapter::cast<cl_device_id>(hDevice), CLPropName, |
| 955 | + SubGroupSizesSize, SubGroupSizes.data(), nullptr)); |
| 956 | + return ReturnValue.template operator()<uint32_t>(SubGroupSizes.data(), |
| 957 | + SubGroupSizes.size()); |
| 958 | + } else { |
| 959 | + return ReturnValue.template operator()<uint32_t>(std::data({1}),1); |
| 960 | + } |
952 | 961 | }
|
953 | 962 | case UR_DEVICE_INFO_EXTENSIONS: {
|
954 | 963 | cl_device_id Dev = cl_adapter::cast<cl_device_id>(hDevice);
|
|
0 commit comments