Skip to content

Commit 15bb682

Browse files
committed
Address feedback
1 parent a79273c commit 15bb682

File tree

2 files changed

+43
-41
lines changed

2 files changed

+43
-41
lines changed

sycl/include/sycl/detail/os_util.hpp

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -106,21 +106,18 @@ void fileTreeWalk(const std::string Path,
106106
std::function<void(const std::string)> Func,
107107
bool ignoreErrors = false);
108108

109-
void *dynLookup(const char *WinName, const char *LinName,
110-
const char *LinuxFallbackLibName, const char *FunName);
111-
112109
// Look up a function name that was dynamically linked
113-
// This is used by the runtime where it needs to manipulate native handles (e.g.
114-
// retaining OpenCL handles). On Windows, the symbol name is looked up in
115-
// `WinName`. In Linux, it uses `LinName` or `LinuxFallbackLibName`.
110+
// This is used by the runtime where it needs to manipulate native handles
111+
// (e.g. retaining OpenCL handles).
116112
//
117113
// The library must already have been loaded (perhaps by UR), otherwise this
118114
// function throws a SYCL runtime exception.
115+
void *dynLookup(const std::vector<const char *> &LibNames, const char *FunName);
116+
119117
template <typename fn>
120-
fn *dynLookupFunction(const char *WinName, const char *LinName,
121-
const char *LinuxFallbackLibName, const char *FunName) {
122-
return reinterpret_cast<fn *>(
123-
dynLookup(WinName, LinName, LinuxFallbackLibName, FunName));
118+
fn *dynLookupFunction(const std::vector<const char *> LibNames,
119+
const char *FunName) {
120+
return reinterpret_cast<fn *>(dynLookup(LibNames, FunName));
124121
}
125122

126123
// On Linux, first try to load from libur_adapter_opencl.so, then
@@ -129,10 +126,15 @@ fn *dynLookupFunction(const char *WinName, const char *LinName,
129126
// symlinked, which is the case with PyPi compiler distribution package.
130127
// We can't load libur_adapter_opencl.so.0 always as the first choice because
131128
// that would break SYCL unittests, which rely on mocking libur_adapter_opencl.
129+
#ifdef __SYCL_RT_OS_WINDOWS
130+
#define OCLLibNames {"OpenCL"}
131+
#else
132+
#define OCLLibNames {"libur_adapter_opencl.so", "libur_adapter_opencl.so.0"}
133+
#endif
134+
132135
#define __SYCL_OCL_CALL(FN, ...) \
133-
(sycl::_V1::detail::dynLookupFunction<decltype(FN)>( \
134-
"OpenCL", "libur_adapter_opencl.so", "libur_adapter_opencl.so.0", \
135-
#FN)(__VA_ARGS__))
136+
(sycl::_V1::detail::dynLookupFunction<decltype(FN)>(OCLLibNames, \
137+
#FN)(__VA_ARGS__))
136138

137139
} // namespace detail
138140
} // namespace _V1

sycl/source/detail/os_util.cpp

Lines changed: 28 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -291,44 +291,44 @@ size_t getDirectorySize(const std::string &Path, bool ignoreErrors) {
291291
return DirSizeVar;
292292
}
293293

294-
// Look up a function name that was dynamically linked
295-
// This is used by the runtime where it needs to manipulate native handles (e.g.
296-
// retaining OpenCL handles). On Windows, the symbol name is looked up in
297-
// `WinName`. In Linux, it uses `LinName`.
294+
// Look up a function name from the given list of shared libraries.
298295
//
299-
// The library must already have been loaded (perhaps by UR), otherwise this
296+
// These library must already have been loaded (perhaps by UR), otherwise this
300297
// function throws a SYCL runtime exception.
301-
void *dynLookup([[maybe_unused]] const char *WinName,
302-
[[maybe_unused]] const char *LinName,
303-
[[maybe_unused]] const char *LinuxFallbackLibName,
298+
void *dynLookup(const std::vector<const char *> &LibNames,
304299
const char *FunName) {
305300
#ifdef __SYCL_RT_OS_WINDOWS
306-
auto handle = GetModuleHandleA(WinName);
307-
if (!handle) {
308-
throw sycl::exception(make_error_code(errc::runtime),
309-
std::string(WinName) + " library is not loaded");
310-
}
311-
auto *retVal = GetProcAddress(handle, FunName);
301+
HMODULE handle = nullptr;
302+
auto GetHandleF = [](const char *LibName) {
303+
return GetModuleHandleA(LibName);
304+
};
305+
auto GetProcF = [&]() { return GetProcAddress(handle, FunName); };
312306
#else
313-
auto handle = dlopen(LinName, RTLD_LAZY | RTLD_NOLOAD);
314-
if (!handle) {
307+
void *handle = nullptr;
308+
auto GetHandleF = [](const char *LibName) {
309+
return dlopen(LibName, RTLD_LAZY | RTLD_NOLOAD);
310+
};
311+
auto GetProcF = [&]() {
312+
auto *retVal = dlsym(handle, FunName);
313+
dlclose(handle);
314+
return retVal;
315+
};
316+
#endif
315317

316-
// Try to open fallback library if provided.
317-
if (LinuxFallbackLibName)
318-
handle = dlopen(LinuxFallbackLibName, RTLD_LAZY | RTLD_NOLOAD);
318+
// Iterate over the list of libraries and try to find one that is loaded.
319+
auto LibNameIt = LibNames.begin();
320+
while (!handle && LibNameIt != LibNames.end())
321+
handle = GetHandleF(*(LibNameIt++));
322+
if (!handle)
323+
throw sycl::exception(make_error_code(errc::runtime),
324+
"Libraries could not be loaded");
319325

320-
if (!handle)
321-
throw sycl::exception(make_error_code(errc::runtime),
322-
std::string(LinName) + " library is not loaded");
323-
}
324-
auto *retVal = dlsym(handle, FunName);
325-
dlclose(handle);
326-
#endif
327-
if (!retVal) {
326+
// Look up the function in the loaded library.
327+
auto *retVal = GetProcF();
328+
if (!retVal)
328329
throw sycl::exception(make_error_code(errc::runtime),
329330
"Symbol " + std::string(FunName) +
330331
" could not be found");
331-
}
332332
return reinterpret_cast<void *>(retVal);
333333
}
334334

0 commit comments

Comments
 (0)