Skip to content

Commit f262cbd

Browse files
committed
[UR][OpenCL] Move banned platforms filtering to UR
Filter out banned OpenCl platforms earlier, in the UR OpenCL adapter instead of the SYCL RT.
1 parent ce54584 commit f262cbd

File tree

2 files changed

+39
-44
lines changed

2 files changed

+39
-44
lines changed

sycl/source/detail/platform_impl.cpp

Lines changed: 2 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -88,33 +88,6 @@ context_impl &platform_impl::khr_get_default_context() {
8888
return *It->second;
8989
}
9090

91-
static bool IsBannedPlatform(platform Platform) {
92-
// The NVIDIA OpenCL platform is currently not compatible with DPC++
93-
// since it is only 1.2 but gets selected by default in many systems
94-
// There is also no support on the PTX backend for OpenCL consumption,
95-
// and there have been some internal reports.
96-
// To avoid problems on default users and deployment of DPC++ on platforms
97-
// where CUDA is available, the OpenCL support is disabled.
98-
//
99-
// There is also no support for the AMD HSA backend for OpenCL consumption,
100-
// as well as reported problems with device queries, so AMD OpenCL support
101-
// is disabled as well.
102-
//
103-
auto IsMatchingOpenCL = [](platform Platform, const std::string_view name) {
104-
const bool HasNameMatch = Platform.get_info<info::platform::name>().find(
105-
name) != std::string::npos;
106-
const auto Backend = detail::getSyclObjImpl(Platform)->getBackend();
107-
const bool IsMatchingOCL = (HasNameMatch && Backend == backend::opencl);
108-
if (detail::ur::trace(detail::ur::TraceLevel::TRACE_ALL) && IsMatchingOCL) {
109-
std::cout << "SYCL_UR_TRACE: " << name
110-
<< " OpenCL platform found but is not compatible." << std::endl;
111-
}
112-
return IsMatchingOCL;
113-
};
114-
return IsMatchingOpenCL(Platform, "NVIDIA CUDA") ||
115-
IsMatchingOpenCL(Platform, "AMD Accelerated Parallel Processing");
116-
}
117-
11891
// Get the vector of platforms supported by a given UR adapter
11992
// replace uses of this with a helper in adapter object, the adapter
12093
// objects will own the ur adapter handles and they'll need to pass them to
@@ -132,25 +105,13 @@ std::vector<platform> platform_impl::getAdapterPlatforms(adapter_impl &Adapter,
132105
for (const auto &UrPlatform : UrPlatforms) {
133106
platform Platform = detail::createSyclObjFromImpl<platform>(
134107
getOrMakePlatformImpl(UrPlatform, Adapter));
135-
const bool IsBanned = IsBannedPlatform(Platform);
136-
bool HasAnyDevices = false;
137-
138-
// Platform.get_devices() increments the device count for the platform
139-
// and if the platform is banned (like OpenCL for AMD), it can cause
140-
// incorrect device numbering, when used with ONEAPI_DEVICE_SELECTOR.
141-
if (!IsBanned)
142-
HasAnyDevices = !Platform.get_devices(info::device_type::all).empty();
108+
bool HasAnyDevices = !Platform.get_devices(info::device_type::all).empty();
143109

144110
if (!Supported) {
145-
if (IsBanned || !HasAnyDevices) {
111+
if (!HasAnyDevices) {
146112
Platforms.push_back(std::move(Platform));
147113
}
148114
} else {
149-
if (IsBanned) {
150-
continue; // bail as early as possible, otherwise banned platforms may
151-
// mess up device counting
152-
}
153-
154115
// The SYCL spec says that a platform has one or more devices. ( SYCL
155116
// 2020 4.6.2 ) If we have an empty platform, we don't report it back
156117
// from platform::get_platforms().

unified-runtime/source/adapters/opencl/platform.cpp

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,32 @@ static cl_int mapURPlatformInfoToCL(ur_platform_info_t URPropName) {
3030
}
3131
}
3232

33+
static bool isBannedOpenCLPlatform(cl_platform_id platform) {
34+
size_t nameSize = 0;
35+
cl_int res = clGetPlatformInfo(platform, CL_PLATFORM_NAME, 0, nullptr, &nameSize);
36+
if (res != CL_SUCCESS || nameSize == 0) {
37+
return false;
38+
}
39+
40+
std::string name(nameSize, '\0');
41+
res = clGetPlatformInfo(platform, CL_PLATFORM_NAME, nameSize, name.data(), nullptr);
42+
if (res != CL_SUCCESS) {
43+
return false;
44+
}
45+
46+
// The NVIDIA OpenCL platform is currently not compatible with DPC++
47+
// since it is only 1.2 but gets selected by default in many systems.
48+
// There is also no support on the PTX backend for OpenCL consumption.
49+
//
50+
// There is also no support for the AMD HSA backend for OpenCL consumption,
51+
// as well as reported problems with device queries, so AMD OpenCL support
52+
// is disabled as well.
53+
bool isBanned = name.find("NVIDIA CUDA") != std::string::npos ||
54+
name.find("AMD Accelerated Parallel Processing") != std::string::npos;
55+
56+
return isBanned;
57+
}
58+
3359
UR_DLLEXPORT ur_result_t UR_APICALL
3460
urPlatformGetInfo(ur_platform_handle_t hPlatform, ur_platform_info_t propName,
3561
size_t propSize, void *pPropValue, size_t *pSizeRet) {
@@ -102,14 +128,22 @@ urPlatformGet(ur_adapter_handle_t, uint32_t NumEntries,
102128
CLPlatforms.data(), nullptr);
103129
CL_RETURN_ON_FAILURE(Res);
104130

131+
// Filter out banned platforms
132+
std::vector<cl_platform_id> FilteredPlatforms;
133+
for (uint32_t i = 0; i < NumPlatforms; i++) {
134+
if (!isBannedOpenCLPlatform(CLPlatforms[i])) {
135+
FilteredPlatforms.push_back(CLPlatforms[i]);
136+
}
137+
}
138+
105139
try {
106-
for (uint32_t i = 0; i < NumPlatforms; i++) {
140+
for (auto &Platform : FilteredPlatforms) {
107141
auto URPlatform =
108-
std::make_unique<ur_platform_handle_t_>(CLPlatforms[i]);
142+
std::make_unique<ur_platform_handle_t_>(Platform);
109143
UR_RETURN_ON_FAILURE(URPlatform->InitDevices());
110144
Adapter->URPlatforms.emplace_back(URPlatform.release());
111145
}
112-
Adapter->NumPlatforms = NumPlatforms;
146+
Adapter->NumPlatforms = static_cast<uint32_t>(FilteredPlatforms.size());
113147
} catch (std::bad_alloc &) {
114148
return UR_RESULT_ERROR_OUT_OF_RESOURCES;
115149
} catch (...) {

0 commit comments

Comments
 (0)