Skip to content

Commit 4dfc9b4

Browse files
authored
[UR][OpenCL] Move banned platforms filtering to UR (#20014)
Filter out banned OpenCL platforms earlier, in the UR OpenCL adapter instead of the SYCL RT.
1 parent 8031a7a commit 4dfc9b4

File tree

2 files changed

+75
-50
lines changed

2 files changed

+75
-50
lines changed

sycl/source/detail/platform_impl.cpp

Lines changed: 2 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -88,33 +88,6 @@ context_impl &platform_impl::khr_get_default_context() {
8888
return *It->second;
8989
}
9090

91-
// Reports whether the given platform is an OpenCL platform that DPC++
// deliberately refuses to expose.
//
// Two vendor platforms are rejected when they are reached through the OpenCL
// backend:
//  * "NVIDIA CUDA": the NVIDIA OpenCL platform is only 1.2, yet it gets
//    selected by default on many systems; the PTX backend has no support for
//    OpenCL consumption, and there have been internal problem reports. OpenCL
//    support is disabled to avoid problems for default users and for
//    deployments of DPC++ on machines where CUDA is available.
//  * "AMD Accelerated Parallel Processing": the AMD HSA backend has no
//    support for OpenCL consumption either, and device queries have been
//    reported as problematic.
//
// Returns true when the platform matches either banned entry.
static bool IsBannedPlatform(platform Platform) {
  // True when Candidate uses the OpenCL backend and its reported name contains
  // Vendor. Emits a trace line for such a match when full tracing is enabled.
  const auto MatchesBannedOpenCL = [](platform Candidate,
                                      const std::string_view Vendor) {
    const auto NamePos =
        Candidate.get_info<info::platform::name>().find(Vendor);
    const bool NameFound = NamePos != std::string::npos;
    const auto CandidateBackend =
        detail::getSyclObjImpl(Candidate)->getBackend();
    const bool Matched = NameFound && CandidateBackend == backend::opencl;

    const bool TraceAll = detail::ur::trace(detail::ur::TraceLevel::TRACE_ALL);
    if (TraceAll && Matched) {
      std::cout << "SYCL_UR_TRACE: " << Vendor
                << " OpenCL platform found but is not compatible." << std::endl;
    }
    return Matched;
  };

  // The second vendor is only probed when the first one did not match.
  if (MatchesBannedOpenCL(Platform, "NVIDIA CUDA")) {
    return true;
  }
  return MatchesBannedOpenCL(Platform, "AMD Accelerated Parallel Processing");
}
117-
11891
// Get the vector of platforms supported by a given UR adapter
11992
// replace uses of this with a helper in adapter object, the adapter
12093
// objects will own the ur adapter handles and they'll need to pass them to
@@ -132,25 +105,13 @@ std::vector<platform> platform_impl::getAdapterPlatforms(adapter_impl &Adapter,
132105
for (const auto &UrPlatform : UrPlatforms) {
133106
platform Platform = detail::createSyclObjFromImpl<platform>(
134107
getOrMakePlatformImpl(UrPlatform, Adapter));
135-
const bool IsBanned = IsBannedPlatform(Platform);
136-
bool HasAnyDevices = false;
137-
138-
// Platform.get_devices() increments the device count for the platform
139-
// and if the platform is banned (like OpenCL for AMD), it can cause
140-
// incorrect device numbering, when used with ONEAPI_DEVICE_SELECTOR.
141-
if (!IsBanned)
142-
HasAnyDevices = !Platform.get_devices(info::device_type::all).empty();
108+
bool HasAnyDevices = !Platform.get_devices(info::device_type::all).empty();
143109

144110
if (!Supported) {
145-
if (IsBanned || !HasAnyDevices) {
111+
if (!HasAnyDevices) {
146112
Platforms.push_back(std::move(Platform));
147113
}
148114
} else {
149-
if (IsBanned) {
150-
continue; // bail as early as possible, otherwise banned platforms may
151-
// mess up device counting
152-
}
153-
154115
// The SYCL spec says that a platform has one or more devices. ( SYCL
155116
// 2020 4.6.2 ) If we have an empty platform, we don't report it back
156117
// from platform::get_platforms().

unified-runtime/source/adapters/opencl/platform.cpp

Lines changed: 73 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,50 @@ static cl_int mapURPlatformInfoToCL(ur_platform_info_t URPropName) {
3030
}
3131
}
3232

33+
// Decides whether an OpenCL platform must be hidden from the adapter, based
// on the platform name reported by clGetPlatformInfo(CL_PLATFORM_NAME).
//
// Banned names (substring match):
//  * "NVIDIA CUDA": the NVIDIA OpenCL platform is currently not compatible
//    with DPC++ since it is only 1.2 but gets selected by default in many
//    systems; there is also no support on the PTX backend for OpenCL
//    consumption.
//  * "AMD Accelerated Parallel Processing": there is no support for the AMD
//    HSA backend for OpenCL consumption, and problems with device queries
//    have been reported.
//
// Returns true for a banned platform. If the name cannot be queried, or is
// reported with size zero, the platform is treated as not banned.
static bool isBannedOpenCLPlatform(cl_platform_id platform) {
  // First call only asks how many bytes the name occupies.
  size_t nameLength = 0;
  const cl_int sizeStatus =
      clGetPlatformInfo(platform, CL_PLATFORM_NAME, 0, nullptr, &nameLength);
  if (sizeStatus != CL_SUCCESS || nameLength == 0) {
    return false;
  }

  // Second call fills a buffer of exactly the reported size.
  std::string platformName(nameLength, '\0');
  const cl_int nameStatus =
      clGetPlatformInfo(platform, CL_PLATFORM_NAME, nameLength,
                        platformName.data(), nullptr);
  if (nameStatus != CL_SUCCESS) {
    return false;
  }

  if (platformName.find("NVIDIA CUDA") != std::string::npos) {
    return true;
  }
  return platformName.find("AMD Accelerated Parallel Processing") !=
         std::string::npos;
}
61+
62+
// Decides whether an OpenCL device must be hidden from the adapter.
//
// A device is banned when its CL_DEVICE_TYPE bitfield has the
// CL_DEVICE_TYPE_ACCELERATOR bit set; the intent is to filter out FPGA
// accelerator devices, whose usage with the OpenCL adapter is deprecated.
//
// Returns true for a banned device. If the device type cannot be queried,
// the device is treated as not banned.
static bool isBannedOpenCLDevice(cl_device_id device) {
  cl_device_type typeBits = 0;
  const cl_int status = clGetDeviceInfo(device, CL_DEVICE_TYPE,
                                        sizeof(typeBits), &typeBits, nullptr);
  if (status != CL_SUCCESS) {
    return false;
  }

  return (typeBits & CL_DEVICE_TYPE_ACCELERATOR) != 0;
}
76+
3377
UR_DLLEXPORT ur_result_t UR_APICALL
3478
urPlatformGetInfo(ur_platform_handle_t hPlatform, ur_platform_info_t propName,
3579
size_t propSize, void *pPropValue, size_t *pSizeRet) {
@@ -102,14 +146,26 @@ urPlatformGet(ur_adapter_handle_t, uint32_t NumEntries,
102146
CLPlatforms.data(), nullptr);
103147
CL_RETURN_ON_FAILURE(Res);
104148

149+
// Filter out banned platforms
150+
std::vector<cl_platform_id> FilteredPlatforms;
151+
for (uint32_t i = 0; i < NumPlatforms; i++) {
152+
if (!isBannedOpenCLPlatform(CLPlatforms[i])) {
153+
FilteredPlatforms.push_back(CLPlatforms[i]);
154+
}
155+
}
156+
105157
try {
106-
for (uint32_t i = 0; i < NumPlatforms; i++) {
107-
auto URPlatform =
108-
std::make_unique<ur_platform_handle_t_>(CLPlatforms[i]);
158+
for (auto &Platform : FilteredPlatforms) {
159+
auto URPlatform = std::make_unique<ur_platform_handle_t_>(Platform);
109160
UR_RETURN_ON_FAILURE(URPlatform->InitDevices());
110-
Adapter->URPlatforms.emplace_back(URPlatform.release());
161+
// Only add platforms that have devices, especially in case all
162+
// devices are banned
163+
if (!URPlatform->Devices.empty()) {
164+
Adapter->URPlatforms.emplace_back(URPlatform.release());
165+
}
111166
}
112-
Adapter->NumPlatforms = NumPlatforms;
167+
Adapter->NumPlatforms =
168+
static_cast<uint32_t>(Adapter->URPlatforms.size());
113169
} catch (std::bad_alloc &) {
114170
return UR_RESULT_ERROR_OUT_OF_RESOURCES;
115171
} catch (...) {
@@ -217,11 +273,19 @@ ur_result_t ur_platform_handle_t_::InitDevices() {
217273

218274
CL_RETURN_ON_FAILURE(Res);
219275

276+
// Filter out banned devices
277+
std::vector<cl_device_id> FilteredDevices;
278+
for (uint32_t i = 0; i < DeviceNum; i++) {
279+
if (!isBannedOpenCLDevice(CLDevices[i])) {
280+
FilteredDevices.push_back(CLDevices[i]);
281+
}
282+
}
283+
220284
try {
221-
Devices.resize(DeviceNum);
222-
for (size_t i = 0; i < DeviceNum; i++) {
223-
Devices[i] =
224-
std::make_unique<ur_device_handle_t_>(CLDevices[i], this, nullptr);
285+
Devices.resize(FilteredDevices.size());
286+
for (size_t i = 0; i < FilteredDevices.size(); i++) {
287+
Devices[i] = std::make_unique<ur_device_handle_t_>(FilteredDevices[i],
288+
this, nullptr);
225289
}
226290
} catch (std::bad_alloc &) {
227291
return UR_RESULT_ERROR_OUT_OF_RESOURCES;

0 commit comments

Comments
 (0)