Skip to content

Commit 949cbd4

Browse files
committed
Add default value for sub group sizes query
Add a check for `cl_intel_required_subgroup_size` extension. if the device does not have it a default value for sub group sizes list is returned.
1 parent 658393f commit 949cbd4

File tree

1 file changed

+20
-11
lines changed

1 file changed

+20
-11
lines changed

source/adapters/opencl/device.cpp

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "common.hpp"
1111
#include "platform.hpp"
1212

13+
#include <algorithm>
1314
#include <array>
1415
#include <cassert>
1516

@@ -938,17 +939,25 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo(ur_device_handle_t hDevice,
938939
return UR_RESULT_SUCCESS;
939940
}
940941
case UR_DEVICE_INFO_SUB_GROUP_SIZES_INTEL: {
941-
// Have to convert size_t to uint32_t
942-
size_t SubGroupSizesSize = 0;
943-
CL_RETURN_ON_FAILURE(
944-
clGetDeviceInfo(cl_adapter::cast<cl_device_id>(hDevice), CLPropName, 0,
945-
nullptr, &SubGroupSizesSize));
946-
std::vector<size_t> SubGroupSizes(SubGroupSizesSize / sizeof(size_t));
947-
CL_RETURN_ON_FAILURE(
948-
clGetDeviceInfo(cl_adapter::cast<cl_device_id>(hDevice), CLPropName,
949-
SubGroupSizesSize, SubGroupSizes.data(), nullptr));
950-
return ReturnValue.template operator()<uint32_t>(SubGroupSizes.data(),
951-
SubGroupSizes.size());
942+
size_t ExtSize = 0;
943+
urDeviceGetInfo(hDevice, UR_DEVICE_INFO_EXTENSIONS, 0, nullptr, &ExtSize);
944+
std::string ExtStr(ExtSize, 0);
945+
urDeviceGetInfo(hDevice, UR_DEVICE_INFO_EXTENSIONS, ExtSize, ExtStr.data(), nullptr);
946+
if (ExtStr.find("cl_intel_required_subgroup_size")!=std::string::npos) {
947+
// Have to convert size_t to uint32_t
948+
size_t SubGroupSizesSize = 0;
949+
CL_RETURN_ON_FAILURE(
950+
clGetDeviceInfo(cl_adapter::cast<cl_device_id>(hDevice), CLPropName, 0,
951+
nullptr, &SubGroupSizesSize));
952+
std::vector<size_t> SubGroupSizes(SubGroupSizesSize / sizeof(size_t));
953+
CL_RETURN_ON_FAILURE(
954+
clGetDeviceInfo(cl_adapter::cast<cl_device_id>(hDevice), CLPropName,
955+
SubGroupSizesSize, SubGroupSizes.data(), nullptr));
956+
return ReturnValue.template operator()<uint32_t>(SubGroupSizes.data(),
957+
SubGroupSizes.size());
958+
} else {
959+
return ReturnValue.template operator()<uint32_t>(std::data({1}),1);
960+
}
952961
}
953962
case UR_DEVICE_INFO_EXTENSIONS: {
954963
cl_device_id Dev = cl_adapter::cast<cl_device_id>(hDevice);

0 commit comments

Comments
 (0)