Skip to content

Commit 7756d71

Browse files
[NFC][SYCL] Explicit types/more comments around getOrBuild (#19556)
1 parent b0304e1 commit 7756d71

File tree

2 files changed

+26
-22
lines changed

2 files changed

+26
-22
lines changed

sycl/source/detail/kernel_program_cache.hpp

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -404,7 +404,7 @@ class KernelProgramCache {
404404
std::pair<std::shared_ptr<ProgramBuildResult>, bool>
405405
getOrInsertProgram(const ProgramCacheKeyT &CacheKey) {
406406
auto LockedCache = acquireCachedPrograms();
407-
auto &ProgCache = LockedCache.get();
407+
ProgramCache &ProgCache = LockedCache.get();
408408
auto [It, DidInsert] = ProgCache.Cache.try_emplace(CacheKey, nullptr);
409409
if (DidInsert) {
410410
It->second = std::make_shared<ProgramBuildResult>(getAdapter());
@@ -426,7 +426,7 @@ class KernelProgramCache {
426426
bool insertBuiltProgram(const ProgramCacheKeyT &CacheKey,
427427
ur_program_handle_t Program) {
428428
auto LockedCache = acquireCachedPrograms();
429-
auto &ProgCache = LockedCache.get();
429+
ProgramCache &ProgCache = LockedCache.get();
430430
auto [It, DidInsert] = ProgCache.Cache.try_emplace(CacheKey, nullptr);
431431
if (DidInsert) {
432432
It->second = std::make_shared<ProgramBuildResult>(getAdapter(),
@@ -491,7 +491,7 @@ class KernelProgramCache {
491491
// Save kernel in fast cache only if the corresponding program is also
492492
// in the cache.
493493
auto LockedCache = acquireCachedPrograms();
494-
auto &ProgCache = LockedCache.get();
494+
ProgramCache &ProgCache = LockedCache.get();
495495
if (ProgCache.ProgramSizeMap.find(CacheVal->MProgramHandle) ==
496496
ProgCache.ProgramSizeMap.end())
497497
return;
@@ -631,7 +631,7 @@ class KernelProgramCache {
631631
while (CurrCacheSize > DesiredCacheSize && !MEvictionList.empty()) {
632632
ProgramCacheKeyT CacheKey = ProgramEvictionList.front();
633633
auto LockedCache = acquireCachedPrograms();
634-
auto &ProgCache = LockedCache.get();
634+
ProgramCache &ProgCache = LockedCache.get();
635635
CurrCacheSize = removeProgramByKey(CacheKey, ProgCache);
636636
// Remove the program from the eviction list.
637637
MEvictionList.popFront();
@@ -748,15 +748,23 @@ class KernelProgramCache {
748748
///
749749
/// \return a pointer to cached build result, return value must not be
750750
/// nullptr.
751+
///
752+
/// Note that build result might be immediately evicted (if it's bigger than
753+
/// current threshold), so the caller *must* assume (potentially shared)
754+
/// ownership. In other words, `std::shared_ptr` in the return type is
755+
/// unavoidable.
751756
template <errc Errc, typename GetCachedBuildFT, typename BuildFT,
752757
typename EvictFT = void *>
753-
auto getOrBuild(GetCachedBuildFT &&GetCachedBuild, BuildFT &&Build,
754-
EvictFT &&EvictFunc = nullptr) {
758+
auto /* std::shared_ptr<BuildResult> */
759+
getOrBuild(GetCachedBuildFT &&GetCachedBuild, BuildFT &&Build,
760+
EvictFT &&EvictFunc = nullptr) {
755761
using BuildState = KernelProgramCache::BuildState;
756762
constexpr size_t MaxAttempts = 2;
757763
for (size_t AttemptCounter = 0;; ++AttemptCounter) {
758-
auto Res = GetCachedBuild();
764+
auto /* std::pair<std::shared_ptr<BuildResult>, bool> */ Res =
765+
GetCachedBuild();
759766
auto &BuildResult = Res.first;
767+
assert(BuildResult != nullptr);
760768
BuildState Expected = BuildState::BS_Initial;
761769
BuildState Desired = BuildState::BS_InProgress;
762770
if (!BuildResult->State.compare_exchange_strong(Expected, Desired)) {
@@ -825,7 +833,7 @@ class KernelProgramCache {
825833

826834
void removeAllRelatedEntries(uint32_t ImageId) {
827835
auto LockedCache = acquireCachedPrograms();
828-
auto &ProgCache = LockedCache.get();
836+
ProgramCache &ProgCache = LockedCache.get();
829837

830838
auto It = std::find_if(
831839
ProgCache.KeyMap.begin(), ProgCache.KeyMap.end(),

sycl/source/detail/program_manager/program_manager.cpp

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1016,15 +1016,13 @@ ProgramManager::getBuiltURProgram(const BinImgWithDeps &ImgWithDeps,
10161016
};
10171017

10181018
auto EvictFunc = [&Cache, &CacheKey](ur_program_handle_t Program,
1019-
bool isBuilt) {
1020-
return Cache.registerProgramFetch(CacheKey, Program, isBuilt);
1019+
bool isBuilt) -> void {
1020+
Cache.registerProgramFetch(CacheKey, Program, isBuilt);
10211021
};
10221022

1023-
auto BuildResult =
1023+
std::shared_ptr<KernelProgramCache::ProgramBuildResult> BuildResult =
10241024
Cache.getOrBuild<errc::build>(GetCachedBuildF, BuildF, EvictFunc);
1025-
1026-
// getOrBuild is not supposed to return nullptr
1027-
assert(BuildResult != nullptr && "Invalid build result");
1025+
assert(BuildResult && "getOrBuild isn't supposed to return nullptr!");
10281026

10291027
ur_program_handle_t ResProgram = BuildResult->Val;
10301028

@@ -1082,8 +1080,6 @@ ProgramManager::getBuiltURProgram(const BinImgWithDeps &ImgWithDeps,
10821080
Adapter.call<UrApiKind::urProgramRetain>(ResProgram);
10831081
}
10841082
CacheLinkedImages();
1085-
// getOrBuild is not supposed to return nullptr
1086-
assert(BuildResult != nullptr && "Invalid build result");
10871083
}
10881084
}
10891085

@@ -1155,9 +1151,9 @@ FastKernelCacheValPtr ProgramManager::getOrCreateKernel(
11551151
Kernel, nullptr, ArgMask, Program, ContextImpl.getAdapter());
11561152
}
11571153

1158-
auto BuildResult = Cache.getOrBuild<errc::invalid>(GetCachedBuildF, BuildF);
1159-
// getOrBuild is not supposed to return nullptr
1160-
assert(BuildResult != nullptr && "Invalid build result");
1154+
std::shared_ptr<KernelProgramCache::KernelBuildResult> BuildResult =
1155+
Cache.getOrBuild<errc::invalid>(GetCachedBuildF, BuildF);
1156+
assert(BuildResult && "getOrBuild isn't supposed to return nullptr!");
11611157
const std::pair<ur_kernel_handle_t, const KernelArgMask *>
11621158
&KernelArgMaskPair = BuildResult->Val;
11631159
auto ret_val = std::make_shared<FastKernelCacheVal>(
@@ -3192,9 +3188,9 @@ ProgramManager::getOrCreateKernel(const context &Context,
31923188
return make_tuple(Kernel, nullptr, ArgMask);
31933189
}
31943190

3195-
auto BuildResult = Cache.getOrBuild<errc::invalid>(GetCachedBuildF, BuildF);
3196-
// getOrBuild is not supposed to return nullptr
3197-
assert(BuildResult != nullptr && "Invalid build result");
3191+
std::shared_ptr<KernelProgramCache::KernelBuildResult> BuildResult =
3192+
Cache.getOrBuild<errc::invalid>(GetCachedBuildF, BuildF);
3193+
assert(BuildResult && "getOrBuild isn't supposed to return nullptr!");
31983194
// If caching is enabled, one copy of the kernel handle will be
31993195
// stored in the cache, and one handle is returned to the
32003196
// caller. In that case, we need to increase the ref count of the

0 commit comments

Comments
 (0)