Skip to content

Commit 467f2cf

Browse files
committed
[UR][OpenCL] Move banned platforms filtering to UR
Filter out banned OpenCl platforms earlier, in the UR OpenCL adapter instead of the SYCL RT.
1 parent ce54584 commit 467f2cf

File tree

2 files changed

+42
-45
lines changed

2 files changed

+42
-45
lines changed

sycl/source/detail/platform_impl.cpp

Lines changed: 2 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -88,33 +88,6 @@ context_impl &platform_impl::khr_get_default_context() {
8888
return *It->second;
8989
}
9090

91-
static bool IsBannedPlatform(platform Platform) {
92-
// The NVIDIA OpenCL platform is currently not compatible with DPC++
93-
// since it is only 1.2 but gets selected by default in many systems
94-
// There is also no support on the PTX backend for OpenCL consumption,
95-
// and there have been some internal reports.
96-
// To avoid problems on default users and deployment of DPC++ on platforms
97-
// where CUDA is available, the OpenCL support is disabled.
98-
//
99-
// There is also no support for the AMD HSA backend for OpenCL consumption,
100-
// as well as reported problems with device queries, so AMD OpenCL support
101-
// is disabled as well.
102-
//
103-
auto IsMatchingOpenCL = [](platform Platform, const std::string_view name) {
104-
const bool HasNameMatch = Platform.get_info<info::platform::name>().find(
105-
name) != std::string::npos;
106-
const auto Backend = detail::getSyclObjImpl(Platform)->getBackend();
107-
const bool IsMatchingOCL = (HasNameMatch && Backend == backend::opencl);
108-
if (detail::ur::trace(detail::ur::TraceLevel::TRACE_ALL) && IsMatchingOCL) {
109-
std::cout << "SYCL_UR_TRACE: " << name
110-
<< " OpenCL platform found but is not compatible." << std::endl;
111-
}
112-
return IsMatchingOCL;
113-
};
114-
return IsMatchingOpenCL(Platform, "NVIDIA CUDA") ||
115-
IsMatchingOpenCL(Platform, "AMD Accelerated Parallel Processing");
116-
}
117-
11891
// Get the vector of platforms supported by a given UR adapter
11992
// replace uses of this with a helper in adapter object, the adapter
12093
// objects will own the ur adapter handles and they'll need to pass them to
@@ -132,25 +105,13 @@ std::vector<platform> platform_impl::getAdapterPlatforms(adapter_impl &Adapter,
132105
for (const auto &UrPlatform : UrPlatforms) {
133106
platform Platform = detail::createSyclObjFromImpl<platform>(
134107
getOrMakePlatformImpl(UrPlatform, Adapter));
135-
const bool IsBanned = IsBannedPlatform(Platform);
136-
bool HasAnyDevices = false;
137-
138-
// Platform.get_devices() increments the device count for the platform
139-
// and if the platform is banned (like OpenCL for AMD), it can cause
140-
// incorrect device numbering, when used with ONEAPI_DEVICE_SELECTOR.
141-
if (!IsBanned)
142-
HasAnyDevices = !Platform.get_devices(info::device_type::all).empty();
108+
bool HasAnyDevices = !Platform.get_devices(info::device_type::all).empty();
143109

144110
if (!Supported) {
145-
if (IsBanned || !HasAnyDevices) {
111+
if (!HasAnyDevices) {
146112
Platforms.push_back(std::move(Platform));
147113
}
148114
} else {
149-
if (IsBanned) {
150-
continue; // bail as early as possible, otherwise banned platforms may
151-
// mess up device counting
152-
}
153-
154115
// The SYCL spec says that a platform has one or more devices. ( SYCL
155116
// 2020 4.6.2 ) If we have an empty platform, we don't report it back
156117
// from platform::get_platforms().

unified-runtime/source/adapters/opencl/platform.cpp

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,35 @@ static cl_int mapURPlatformInfoToCL(ur_platform_info_t URPropName) {
3030
}
3131
}
3232

33+
static bool isBannedOpenCLPlatform(cl_platform_id platform) {
34+
size_t nameSize = 0;
35+
cl_int res =
36+
clGetPlatformInfo(platform, CL_PLATFORM_NAME, 0, nullptr, &nameSize);
37+
if (res != CL_SUCCESS || nameSize == 0) {
38+
return false;
39+
}
40+
41+
std::string name(nameSize, '\0');
42+
res = clGetPlatformInfo(platform, CL_PLATFORM_NAME, nameSize, name.data(),
43+
nullptr);
44+
if (res != CL_SUCCESS) {
45+
return false;
46+
}
47+
48+
// The NVIDIA OpenCL platform is currently not compatible with DPC++
49+
// since it is only 1.2 but gets selected by default in many systems.
50+
// There is also no support on the PTX backend for OpenCL consumption.
51+
//
52+
// There is also no support for the AMD HSA backend for OpenCL consumption,
53+
// as well as reported problems with device queries, so AMD OpenCL support
54+
// is disabled as well.
55+
bool isBanned =
56+
name.find("NVIDIA CUDA") != std::string::npos ||
57+
name.find("AMD Accelerated Parallel Processing") != std::string::npos;
58+
59+
return isBanned;
60+
}
61+
3362
UR_DLLEXPORT ur_result_t UR_APICALL
3463
urPlatformGetInfo(ur_platform_handle_t hPlatform, ur_platform_info_t propName,
3564
size_t propSize, void *pPropValue, size_t *pSizeRet) {
@@ -102,14 +131,21 @@ urPlatformGet(ur_adapter_handle_t, uint32_t NumEntries,
102131
CLPlatforms.data(), nullptr);
103132
CL_RETURN_ON_FAILURE(Res);
104133

134+
// Filter out banned platforms
135+
std::vector<cl_platform_id> FilteredPlatforms;
136+
for (uint32_t i = 0; i < NumPlatforms; i++) {
137+
if (!isBannedOpenCLPlatform(CLPlatforms[i])) {
138+
FilteredPlatforms.push_back(CLPlatforms[i]);
139+
}
140+
}
141+
105142
try {
106-
for (uint32_t i = 0; i < NumPlatforms; i++) {
107-
auto URPlatform =
108-
std::make_unique<ur_platform_handle_t_>(CLPlatforms[i]);
143+
for (auto &Platform : FilteredPlatforms) {
144+
auto URPlatform = std::make_unique<ur_platform_handle_t_>(Platform);
109145
UR_RETURN_ON_FAILURE(URPlatform->InitDevices());
110146
Adapter->URPlatforms.emplace_back(URPlatform.release());
111147
}
112-
Adapter->NumPlatforms = NumPlatforms;
148+
Adapter->NumPlatforms = static_cast<uint32_t>(FilteredPlatforms.size());
113149
} catch (std::bad_alloc &) {
114150
return UR_RESULT_ERROR_OUT_OF_RESOURCES;
115151
} catch (...) {

0 commit comments

Comments
 (0)