Skip to content

Commit 46296dc

Browse files
[SYCL] Fix adjusted kernel name handling in preview RT build (#19582)
Remove usage of a temporary string where the program manager receives and stores a string view in the preview build.
1 parent 0de8e83 commit 46296dc

File tree

4 files changed

+53
-28
lines changed

4 files changed

+53
-28
lines changed

sycl/source/detail/device_image_impl.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,18 +21,18 @@ std::shared_ptr<kernel_impl> device_image_impl::tryGetExtensionKernel(
2121
!((getOriginMask() & ImageOriginSYCLBIN) && hasKernelName(Name)))
2222
return nullptr;
2323

24-
std::string AdjustedName = adjustKernelName(Name);
24+
std::string_view AdjustedName = getAdjustedKernelNameStrView(Name);
2525
if (MRTCBinInfo && MRTCBinInfo->MLanguage == syclex::source_language::sycl) {
2626
auto &PM = ProgramManager::getInstance();
2727
for (const std::string &Prefix : MRTCBinInfo->MPrefixes) {
28-
auto KID = PM.tryGetSYCLKernelID(Prefix + AdjustedName);
28+
auto KID = PM.tryGetSYCLKernelID(Prefix + std::string(AdjustedName));
2929

3030
if (!KID || !has_kernel(*KID))
3131
continue;
3232

3333
auto UrProgram = get_ur_program();
3434
auto [UrKernel, CacheMutex, ArgMask] =
35-
PM.getOrCreateKernel(Context, AdjustedName,
35+
PM.getOrCreateKernel(Context, KernelNameStrT(AdjustedName),
3636
/*PropList=*/{}, UrProgram);
3737
return std::make_shared<kernel_impl>(
3838
std::move(UrKernel), *getSyclObjImpl(Context), shared_from_this(),
@@ -44,7 +44,7 @@ std::shared_ptr<kernel_impl> device_image_impl::tryGetExtensionKernel(
4444
ur_program_handle_t UrProgram = get_ur_program();
4545
detail::adapter_impl &Adapter = getSyclObjImpl(Context)->getAdapter();
4646
Managed<ur_kernel_handle_t> UrKernel{Adapter};
47-
Adapter.call<UrApiKind::urKernelCreate>(UrProgram, AdjustedName.c_str(),
47+
Adapter.call<UrApiKind::urKernelCreate>(UrProgram, AdjustedName.data(),
4848
&UrKernel);
4949

5050
const KernelArgMask *ArgMask = nullptr;

sycl/source/detail/device_image_impl.hpp

Lines changed: 43 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -152,9 +152,12 @@ class ManagedDeviceBinaries {
152152
sycl_device_binaries MBinaries;
153153
};
154154

155+
// Using ordered containers for heterogenous lookup.
156+
// TODO change to unordered containers after switching to C++20.
155157
using MangledKernelNameMapT = std::map<std::string, std::string, std::less<>>;
156158
using KernelNameSetT = std::set<std::string, std::less<>>;
157-
using KernelNameToArgMaskMap = std::unordered_map<std::string, KernelArgMask>;
159+
using KernelNameToArgMaskMap =
160+
std::map<std::string, KernelArgMask, std::less<>>;
158161

159162
// Information unique to images compiled at runtime through the
160163
// ext_oneapi_kernel_compiler extension.
@@ -619,32 +622,21 @@ class device_image_impl
619622
#pragma warning(pop)
620623
#endif
621624

622-
std::string adjustKernelName(std::string_view Name) const {
623-
if (MOrigins & ImageOriginSYCLBIN) {
624-
constexpr std::string_view KernelPrefix = "__sycl_kernel_";
625-
if (Name.size() > KernelPrefix.size() &&
626-
Name.substr(0, KernelPrefix.size()) == KernelPrefix)
627-
return Name.data();
628-
return std::string{KernelPrefix} + Name.data();
629-
}
630-
631-
if (!MRTCBinInfo.has_value())
632-
return Name.data();
633-
634-
if (MRTCBinInfo->MLanguage == syclex::source_language::sycl) {
635-
auto It = MRTCBinInfo->MMangledKernelNames.find(Name);
636-
if (It != MRTCBinInfo->MMangledKernelNames.end())
637-
return It->second;
638-
}
625+
// Assumes the kernel is contained within this image.
626+
std::string_view getAdjustedKernelNameStrView(std::string_view Name) const {
627+
return getAdjustedKernelNameImpl<std::string_view>(Name);
628+
}
639629

640-
return Name.data();
630+
std::string getAdjustedKernelNameStr(std::string_view Name) const {
631+
return getAdjustedKernelNameImpl<std::string>(Name);
641632
}
642633

643634
bool hasKernelName(std::string_view Name) const {
644635
return (getOriginMask() &
645636
(ImageOriginKernelCompiler | ImageOriginSYCLBIN)) &&
646637
!Name.empty() &&
647-
MKernelNames.find(adjustKernelName(Name)) != MKernelNames.end();
638+
MKernelNames.find(getAdjustedKernelNameStr(Name)) !=
639+
MKernelNames.end();
648640
}
649641

650642
std::shared_ptr<kernel_impl>
@@ -840,6 +832,37 @@ class device_image_impl
840832
}
841833

842834
private:
835+
template <typename RetT>
836+
RetT getAdjustedKernelNameImpl(std::string_view Name) const {
837+
if (MOrigins & ImageOriginSYCLBIN) {
838+
constexpr std::string_view KernelPrefix = "__sycl_kernel_";
839+
if (Name.size() > KernelPrefix.size() &&
840+
Name.substr(0, KernelPrefix.size()) == KernelPrefix)
841+
return RetT(Name);
842+
std::string AdjustedNameStr =
843+
std::string(KernelPrefix) + std::string(Name);
844+
if constexpr (std::is_same_v<RetT, std::string>) {
845+
return AdjustedNameStr;
846+
} else {
847+
static_assert(std::is_same_v<RetT, std::string_view>);
848+
auto It = MKernelNames.find(AdjustedNameStr);
849+
assert(It != MKernelNames.end() && "Adjusted name not found");
850+
return *It;
851+
}
852+
}
853+
854+
if (!MRTCBinInfo.has_value())
855+
return RetT(Name);
856+
857+
if (MRTCBinInfo->MLanguage == syclex::source_language::sycl) {
858+
auto It = MRTCBinInfo->MMangledKernelNames.find(Name);
859+
if (It != MRTCBinInfo->MMangledKernelNames.end())
860+
return It->second;
861+
}
862+
863+
return RetT(Name);
864+
}
865+
843866
bool hasRTDeviceBinaryImage() const noexcept {
844867
return std::holds_alternative<const RTDeviceBinaryImage *>(MBinImage) &&
845868
get_bin_image_ref() != nullptr;

sycl/source/detail/kernel_bundle_impl.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -692,7 +692,7 @@ class kernel_bundle_impl
692692
throw sycl::exception(make_error_code(errc::invalid),
693693
"kernel '" + Name + "' not found in kernel_bundle");
694694

695-
return It->adjustKernelName(Name);
695+
return It->getAdjustedKernelNameStr(Name);
696696
}
697697

698698
bool ext_oneapi_has_device_global(const std::string &Name) const {

sycl/source/detail/program_manager/program_manager.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2816,7 +2816,7 @@ ProgramManager::compile(const DevImgPlainWithDeps &ImgWithDeps,
28162816
setSpecializationConstants(InputImpl, Prog, Adapter);
28172817

28182818
KernelNameSetT KernelNames = InputImpl.getKernelNames();
2819-
std::unordered_map<std::string, KernelArgMask> EliminatedKernelArgMasks =
2819+
std::map<std::string, KernelArgMask, std::less<>> EliminatedKernelArgMasks =
28202820
InputImpl.getEliminatedKernelArgMasks();
28212821

28222822
std::optional<detail::KernelCompilerBinaryInfo> RTCInfo =
@@ -3006,7 +3006,8 @@ ProgramManager::link(const std::vector<device_image_plain> &Imgs,
30063006
RTCInfoPtrs;
30073007
RTCInfoPtrs.reserve(Imgs.size());
30083008
KernelNameSetT MergedKernelNames;
3009-
std::unordered_map<std::string, KernelArgMask> MergedEliminatedKernelArgMasks;
3009+
std::map<std::string, KernelArgMask, std::less<>>
3010+
MergedEliminatedKernelArgMasks;
30103011
for (const device_image_plain &DevImg : Imgs) {
30113012
device_image_impl &DevImgImpl = *getSyclObjImpl(DevImg);
30123013
CombinedOrigins |= DevImgImpl.getOriginMask();
@@ -3088,7 +3089,8 @@ ProgramManager::build(const DevImgPlainWithDeps &DevImgWithDeps,
30883089
RTCInfoPtrs;
30893090
RTCInfoPtrs.reserve(DevImgWithDeps.size());
30903091
KernelNameSetT MergedKernelNames;
3091-
std::unordered_map<std::string, KernelArgMask> MergedEliminatedKernelArgMasks;
3092+
std::map<std::string, KernelArgMask, std::less<>>
3093+
MergedEliminatedKernelArgMasks;
30923094
for (const device_image_plain &DevImg : DevImgWithDeps) {
30933095
device_image_impl &DevImgImpl = *getSyclObjImpl(DevImg);
30943096
RTCInfoPtrs.emplace_back(&(DevImgImpl.getRTCInfo()));

0 commit comments

Comments
 (0)