Skip to content

Commit 3157885

Browse files
Check kernel info ownership with mixed image origins
1 parent 65e213b commit 3157885

File tree

4 files changed

+22
-2
lines changed

4 files changed

+22
-2
lines changed

sycl/source/detail/kernel_impl.cpp

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,11 +59,11 @@ kernel_impl::kernel_impl(Managed<ur_kernel_handle_t> &&Kernel,
5959
MKernelBundleImpl(KernelBundleImpl.shared_from_this()),
6060
MIsInterop(MDeviceImageImpl->getOriginMask() & ImageOriginInterop),
6161
MKernelArgMaskPtr{ArgMask}, MCacheMutex{CacheMutex},
62-
MOwnsDeviceKernelInfo(MDeviceImageImpl->getOriginMask() &
63-
~ImageOriginSYCLOffline),
62+
MOwnsDeviceKernelInfo(checkOwnsDeviceKernelInfo()),
6463
MDeviceKernelInfo(MOwnsDeviceKernelInfo
6564
? createCompileTimeKernelInfo(getName())
6665
: createCompileTimeKernelInfo()) {
66+
6767
// Enable USM indirect access for interop and non-sycl-jit source kernels.
6868
// sycl-jit kernels will enable this if needed through the regular kernel
6969
// path.
@@ -123,6 +123,17 @@ std::string_view kernel_impl::getName() const {
123123
return MName;
124124
}
125125

126+
bool kernel_impl::checkOwnsDeviceKernelInfo() {
127+
// If the image originates from something other than standard offline
128+
// compilation, this kernel needs to own its info structure.
129+
// We could also have a mixed origin image, in which case the device kernel
130+
// info might reside in program manager.
131+
return MDeviceImageImpl->getOriginMask() != ImageOriginSYCLOffline &&
132+
(!(MDeviceImageImpl->getOriginMask() & ImageOriginSYCLOffline) ||
133+
!ProgramManager::getInstance().tryGetDeviceKernelInfo(
134+
static_cast<KernelNameStrT>(getName())));
135+
}
136+
126137
bool kernel_impl::isBuiltInKernel(device_impl &Device) const {
127138
auto BuiltInKernels = Device.get_info<info::device::built_in_kernel_ids>();
128139
if (BuiltInKernels.empty())

sycl/source/detail/kernel_impl.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,7 @@ class kernel_impl {
239239
std::mutex *getCacheMutex() const { return MCacheMutex; }
240240
std::string_view getName() const;
241241

242+
bool checkOwnsDeviceKernelInfo();
242243
DeviceKernelInfo &getDeviceKernelInfo() {
243244
return MOwnsDeviceKernelInfo
244245
? MDeviceKernelInfo

sycl/source/detail/program_manager/program_manager.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1833,6 +1833,13 @@ ProgramManager::getDeviceKernelInfo(KernelNameStrRefT KernelName) {
18331833
return It->second;
18341834
}
18351835

1836+
DeviceKernelInfo *
1837+
ProgramManager::tryGetDeviceKernelInfo(KernelNameStrRefT KernelName) {
1838+
std::lock_guard<std::mutex> Guard(m_DeviceKernelInfoMapMutex);
1839+
auto It = m_DeviceKernelInfoMap.find(KernelName);
1840+
return It != m_DeviceKernelInfoMap.end() ? &It->second : nullptr;
1841+
}
1842+
18361843
static bool isBfloat16DeviceLibImage(sycl_device_binary RawImg,
18371844
uint32_t *LibVersion = nullptr) {
18381845
sycl_device_binary_property_set ImgPS;

sycl/source/detail/program_manager/program_manager.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -372,6 +372,7 @@ class ProgramManager {
372372

373373
DeviceKernelInfo &getDeviceKernelInfo(const CompileTimeKernelInfoTy &Info);
374374
DeviceKernelInfo &getDeviceKernelInfo(KernelNameStrRefT KernelName);
375+
DeviceKernelInfo *tryGetDeviceKernelInfo(KernelNameStrRefT KernelName);
375376

376377
std::set<const RTDeviceBinaryImage *>
377378
getRawDeviceImages(const std::vector<kernel_id> &KernelIDs);

0 commit comments

Comments
 (0)