Skip to content

Commit c04057c

Browse files
committed
Implement cache eviction
1 parent 8307767 commit c04057c

File tree

3 files changed

+306
-10
lines changed

3 files changed

+306
-10
lines changed

sycl/source/detail/kernel_program_cache.hpp

Lines changed: 272 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@
2121
#include <atomic>
2222
#include <condition_variable>
2323
#include <iomanip>
24+
#include <list>
2425
#include <mutex>
26+
#include <numeric>
2527
#include <set>
2628
#include <thread>
2729
#include <type_traits>
@@ -130,7 +132,15 @@ class KernelProgramCache {
130132
struct ProgramCache {
131133
::boost::unordered_map<ProgramCacheKeyT, ProgramBuildResultPtr> Cache;
132134
::boost::unordered_multimap<CommonProgramKeyT, ProgramCacheKeyT> KeyMap;
135+
// Mapping between a UR program and its size.
136+
::boost::unordered_map<ur_program_handle_t, size_t> ProgramSizeMap;
133137

138+
size_t ProgramCacheSizeInBytes = 0;
139+
inline size_t GetProgramCacheSizeInBytes() const noexcept {
140+
return ProgramCacheSizeInBytes;
141+
}
142+
143+
// Returns number of entries in the cache.
134144
size_t size() const noexcept { return Cache.size(); }
135145
};
136146

@@ -175,6 +185,54 @@ class KernelProgramCache {
175185
using KernelFastCacheT =
176186
::boost::unordered_flat_map<KernelFastCacheKeyT, KernelFastCacheValT>;
177187

188+
// DS to hold data and functions related to Program cache eviction.
189+
struct EvictionListT {
190+
// Linked list of cache entries to be evicted in case of cache overflow.
191+
std::list<ProgramCacheKeyT> MProgramEvictionList;
192+
193+
// Mapping between program handle and the iterator to the eviction list.
194+
::boost::unordered_map<ProgramCacheKeyT,
195+
std::list<ProgramCacheKeyT>::iterator>
196+
MProgramToEvictionListMap;
197+
198+
void clear() {
199+
MProgramEvictionList.clear();
200+
MProgramToEvictionListMap.clear();
201+
}
202+
203+
void emplaceBack(const ProgramCacheKeyT &CacheKey) {
204+
MProgramEvictionList.emplace_back(CacheKey);
205+
206+
// In std::list, the iterators are not invalidated when elements are
207+
// added/removed/moved to the list. So, we can safely store the iterators.
208+
MProgramToEvictionListMap[CacheKey] =
209+
std::prev(MProgramEvictionList.end());
210+
traceProgram("Program added to the end of eviction list.", CacheKey);
211+
}
212+
213+
// This function is called on the hot path, whenever a kernel/program
214+
// is accessed. So, it should be very fast.
215+
void moveToEnd(const ProgramCacheKeyT &CacheKey) {
216+
auto It = MProgramToEvictionListMap.find(CacheKey);
217+
if (It != MProgramToEvictionListMap.end()) {
218+
MProgramEvictionList.splice(MProgramEvictionList.end(),
219+
MProgramEvictionList, It->second);
220+
traceProgram("Program moved to the end of eviction list.", CacheKey);
221+
} else
222+
// This should never happen.
223+
assert(false && "Program not found in the eviction list.");
224+
}
225+
226+
bool empty() { return MProgramEvictionList.empty(); }
227+
228+
void popFront() {
229+
if (!MProgramEvictionList.empty()) {
230+
MProgramToEvictionListMap.erase(MProgramEvictionList.front());
231+
MProgramEvictionList.pop_front();
232+
}
233+
}
234+
};
235+
178236
~KernelProgramCache() = default;
179237

180238
void setContextPtr(const ContextPtr &AContext) { MParentContext = AContext; }
@@ -188,12 +246,23 @@ class KernelProgramCache {
188246

189247
int ImageId = CacheKey.first.second;
190248
std::stringstream DeviceList;
249+
std::vector<unsigned char> SerializedObjVec = CacheKey.first.first;
250+
251+
// Convert spec constants to string. Spec constants are stored as
252+
// ASCII values, so we need to convert them to int and then to
253+
// string.
254+
std::string SerializedObjString;
255+
for (unsigned char c : SerializedObjVec)
256+
SerializedObjString += std::to_string((int)c) + ",";
257+
191258
for (const auto &Device : CacheKey.second)
192259
DeviceList << "0x" << std::setbase(16)
193260
<< reinterpret_cast<uintptr_t>(Device) << ",";
194261

195262
std::string Identifier = "[Key:{imageId = " + std::to_string(ImageId) +
196-
",urDevice = " + DeviceList.str() + "}]: ";
263+
",urDevice = " + DeviceList.str() +
264+
", serializedObj = " + SerializedObjString +
265+
"}]: ";
197266

198267
std::cerr << "[In-Memory Cache][Thread Id:" << std::this_thread::get_id()
199268
<< "][Program Cache]" << Identifier << Msg << std::endl;
@@ -259,8 +328,7 @@ class KernelProgramCache {
259328
std::make_pair(CacheKey.first.second, CacheKey.second);
260329
ProgCache.KeyMap.emplace(CommonKey, CacheKey);
261330
traceProgram("Program inserted.", CacheKey);
262-
} else
263-
traceProgram("Program fetched.", CacheKey);
331+
}
264332
return DidInsert;
265333
}
266334

@@ -291,11 +359,186 @@ class KernelProgramCache {
291359

292360
template <typename KeyT, typename ValT>
293361
void saveKernel(KeyT &&CacheKey, ValT &&CacheVal) {
362+
363+
ur_program_handle_t Program = std::get<3>(CacheVal);
364+
// Save kernel in fast cache only if the corresponding program is also
365+
// in the cache.
366+
{
367+
auto LockedCache = acquireCachedPrograms();
368+
auto &ProgCache = LockedCache.get();
369+
if (ProgCache.ProgramSizeMap.find(Program) ==
370+
ProgCache.ProgramSizeMap.end())
371+
return;
372+
}
373+
294374
std::unique_lock<std::mutex> Lock(MKernelFastCacheMutex);
295375
// if no insertion took place, thus some other thread has already inserted
296376
// smth in the cache
297377
traceKernel("Kernel inserted.", std::get<3>(CacheKey), true);
298378
MKernelFastCache.emplace(CacheKey, CacheVal);
379+
380+
if (SYCLConfig<
381+
SYCL_IN_MEM_CACHE_EVICTION_THRESHOLD>::getProgramCacheSize()) {
382+
383+
// Save reference between the program and the fast cache key.
384+
MProgramToKernelFastCacheKeyMap[Program].emplace_back(CacheKey);
385+
}
386+
}
387+
388+
// Evict programs from cache to free up space.
389+
void evictPrograms(int DesiredCacheSize, int CurrentCacheSize) {
390+
391+
// Figure out how many programs from the beginning we need to evict.
392+
// [FIXME] Will the race on MCachedPrograms.Cache.empty() be a problem?
393+
size_t EvictSize = CurrentCacheSize - DesiredCacheSize;
394+
if (EvictSize <= 0 || MCachedPrograms.Cache.empty())
395+
return;
396+
397+
// Evict programs from the beginning of the cache.
398+
{
399+
std::lock_guard<std::mutex> Lock(MProgramEvictionListMutex);
400+
401+
// Traverse the eviction list and remove the LRU programs.
402+
// The LRU programs will be at the front of the list.
403+
while (EvictSize > 0 && !MEvictionList.empty()) {
404+
ProgramCacheKeyT CacheKey = MEvictionList.MProgramEvictionList.front();
405+
auto LockedCache = acquireCachedPrograms();
406+
auto &ProgCache = LockedCache.get();
407+
auto It = ProgCache.Cache.find(CacheKey);
408+
if (It != ProgCache.Cache.end()) {
409+
// We are about to remove this program now.
410+
// (1) Remove it from KernelPerProgram cache.
411+
// (2) Remove corresponding entries from KernelFastCache.
412+
// (3) Remove it from ProgramCache KeyMap.
413+
// (4) Remove it from the ProgramCache.
414+
// (5) Remove it from ProgramSizeMap.
415+
// (6) Update the cache size.
416+
417+
// Remove entry from the KernelsPerProgram cache.
418+
ur_program_handle_t NativePrg = It->second->Val;
419+
{
420+
auto LockedCacheKP = acquireKernelsPerProgramCache();
421+
LockedCacheKP.get().erase(NativePrg);
422+
}
423+
424+
// Remove corresponding entries from KernelFastCache.
425+
auto FastCacheKeyItr =
426+
MProgramToKernelFastCacheKeyMap.find(NativePrg);
427+
if (FastCacheKeyItr != MProgramToKernelFastCacheKeyMap.end()) {
428+
for (const auto &FastCacheKey : FastCacheKeyItr->second) {
429+
std::unique_lock<std::mutex> Lock(MKernelFastCacheMutex);
430+
MKernelFastCache.erase(FastCacheKey);
431+
traceKernel("Kernel evicted.", std::get<3>(FastCacheKey), true);
432+
}
433+
MProgramToKernelFastCacheKeyMap.erase(FastCacheKeyItr);
434+
}
435+
436+
// Remove entry from ProgramCache KeyMap.
437+
CommonProgramKeyT CommonKey =
438+
std::make_pair(CacheKey.first.second, CacheKey.second);
439+
// Since KeyMap is a multi-map, we need to iterate over all entries
440+
// with this CommonKey and remove those that match the CacheKey.
441+
auto KeyMapItrRange = LockedCache.get().KeyMap.equal_range(CommonKey);
442+
for (auto KeyMapItr = KeyMapItrRange.first;
443+
KeyMapItr != KeyMapItrRange.second; ++KeyMapItr) {
444+
if (KeyMapItr->second == CacheKey) {
445+
LockedCache.get().KeyMap.erase(KeyMapItr);
446+
break;
447+
}
448+
}
449+
450+
// Get size of the program.
451+
size_t ProgramSize = MCachedPrograms.ProgramSizeMap[It->second->Val];
452+
// Evict program from the cache.
453+
ProgCache.Cache.erase(It);
454+
// Remove program size from the cache size.
455+
MCachedPrograms.ProgramCacheSizeInBytes -= ProgramSize;
456+
// Remove program size from the eviction list.
457+
EvictSize -= ProgramSize;
458+
459+
MCachedPrograms.ProgramSizeMap.erase(NativePrg);
460+
461+
traceProgram("Program evicted.", CacheKey);
462+
} else
463+
// This should never happen.
464+
assert(false && "Program not found in the cache.");
465+
466+
// Remove the program from the eviction list.
467+
MEvictionList.popFront();
468+
}
469+
}
470+
}
471+
472+
// Register that a program has been fetched from the cache.
473+
// If it is the first time the program is fetched, add it to the eviction
474+
// list.
475+
void registerProgramFetch(const ProgramCacheKeyT &CacheKey,
476+
const ur_program_handle_t &Program,
477+
const bool IsBuilt) {
478+
479+
static size_t ProgramCacheEvictionThreshold = static_cast<size_t>(
480+
SYCLConfig<
481+
SYCL_IN_MEM_CACHE_EVICTION_THRESHOLD>::getProgramCacheSize());
482+
483+
// No need to populate the eviction list if eviction is disabled.
484+
if (ProgramCacheEvictionThreshold == 0)
485+
return;
486+
487+
// If the program is not in the cache, add it to the cache.
488+
if (IsBuilt) {
489+
// This is the first time we are adding this entry. Add it to the end of
490+
// eviction list.
491+
{
492+
std::lock_guard<std::mutex> Lock(MProgramEvictionListMutex);
493+
MEvictionList.emplaceBack(CacheKey);
494+
}
495+
496+
// Store size of the program and check if we need to evict some entries.
497+
// Get Size of the program.
498+
size_t ProgramSize;
499+
auto Adapter = getAdapter();
500+
501+
try {
502+
// Get number of devices this program was built for.
503+
unsigned int DeviceNum = 0;
504+
Adapter->call<UrApiKind::urProgramGetInfo>(
505+
Program, UR_PROGRAM_INFO_NUM_DEVICES, sizeof(DeviceNum), &DeviceNum,
506+
nullptr);
507+
508+
// Get binary sizes for each device.
509+
std::vector<size_t> BinarySizes(DeviceNum);
510+
Adapter->call<UrApiKind::urProgramGetInfo>(
511+
Program, UR_PROGRAM_INFO_BINARY_SIZES,
512+
sizeof(size_t) * BinarySizes.size(), BinarySizes.data(), nullptr);
513+
514+
// Sum up binary sizes.
515+
ProgramSize =
516+
std::accumulate(BinarySizes.begin(), BinarySizes.end(), 0);
517+
} catch (const exception &Ex) {
518+
std::cerr << "Failed to get program size: " << Ex.what() << std::endl;
519+
std::rethrow_exception(std::current_exception());
520+
}
521+
// Store program size in the cache.
522+
size_t CurrCacheSize = 0;
523+
{
524+
std::lock_guard<std::mutex> Lock(MProgramCacheMutex);
525+
MCachedPrograms.ProgramSizeMap[Program] = ProgramSize;
526+
MCachedPrograms.ProgramCacheSizeInBytes += ProgramSize;
527+
CurrCacheSize = MCachedPrograms.ProgramCacheSizeInBytes;
528+
}
529+
530+
// Evict programs if the cache size exceeds the threshold.
531+
if (CurrCacheSize > ProgramCacheEvictionThreshold)
532+
evictPrograms(ProgramCacheEvictionThreshold, CurrCacheSize);
533+
}
534+
// If the program is already in the cache, move it to the end of the list.
535+
// Since we are following LRU eviction policy, we need to move the program
536+
// to the end of the list. Items in the front of the list are the least
537+
// recently This code path is "hot" and should be very fast.
538+
else {
539+
std::lock_guard<std::mutex> Lock(MProgramEvictionListMutex);
540+
MEvictionList.moveToEnd(CacheKey);
541+
}
299542
}
300543

301544
/// Clears cache state.
@@ -308,6 +551,11 @@ class KernelProgramCache {
308551
MCachedPrograms = ProgramCache{};
309552
MKernelsPerProgramCache = KernelCacheT{};
310553
MKernelFastCache = KernelFastCacheT{};
554+
MProgramToKernelFastCacheKeyMap.clear();
555+
556+
// Clear the eviction list while holding its mutex.
557+
std::lock_guard<std::mutex> L4(MProgramEvictionListMutex);
558+
MEvictionList.clear();
311559
}
312560

313561
/// Try to fetch entity (kernel or program) from cache. If there is no such
@@ -332,8 +580,10 @@ class KernelProgramCache {
332580
///
333581
/// \return a pointer to cached build result, return value must not be
334582
/// nullptr.
335-
template <errc Errc, typename GetCachedBuildFT, typename BuildFT>
336-
auto getOrBuild(GetCachedBuildFT &&GetCachedBuild, BuildFT &&Build) {
583+
template <errc Errc, typename GetCachedBuildFT, typename BuildFT,
584+
typename EvicFT = void *>
585+
auto getOrBuild(GetCachedBuildFT &&GetCachedBuild, BuildFT &&Build,
586+
EvicFT &&EvictFnc = nullptr) {
337587
using BuildState = KernelProgramCache::BuildState;
338588
constexpr size_t MaxAttempts = 2;
339589
for (size_t AttemptCounter = 0;; ++AttemptCounter) {
@@ -347,8 +597,11 @@ class KernelProgramCache {
347597
BuildState NewState = BuildResult->waitUntilTransition();
348598

349599
// Build succeeded.
350-
if (NewState == BuildState::BS_Done)
600+
if (NewState == BuildState::BS_Done) {
601+
if constexpr (!std::is_same_v<EvicFT, void *>)
602+
EvictFnc(BuildResult->Val, 0);
351603
return BuildResult;
604+
}
352605

353606
// Build failed, or this is the last attempt.
354607
if (NewState == BuildState::BS_Failed ||
@@ -372,6 +625,9 @@ class KernelProgramCache {
372625
try {
373626
BuildResult->Val = Build();
374627

628+
if constexpr (!std::is_same_v<EvicFT, void *>)
629+
EvictFnc(BuildResult->Val, 1);
630+
375631
BuildResult->updateAndNotify(BuildState::BS_Done);
376632
return BuildResult;
377633
} catch (const exception &Ex) {
@@ -405,6 +661,16 @@ class KernelProgramCache {
405661

406662
std::mutex MKernelFastCacheMutex;
407663
KernelFastCacheT MKernelFastCache;
664+
665+
// Map from a program handle to the fast kernel cache keys saved for it.
666+
// MKernelFastCacheMutex will be used for synchronization.
667+
std::unordered_map<ur_program_handle_t, std::vector<KernelFastCacheKeyT>>
668+
MProgramToKernelFastCacheKeyMap;
669+
670+
EvictionListT MEvictionList;
671+
// Mutex that will be used when accessing the eviction list.
672+
std::mutex MProgramEvictionListMutex;
673+
408674
friend class ::MockKernelProgramCache;
409675

410676
const AdapterPtr &getAdapter();

0 commit comments

Comments
 (0)