2121#include < atomic>
2222#include < condition_variable>
2323#include < iomanip>
24+ #include < list>
2425#include < mutex>
26+ #include < numeric>
2527#include < set>
2628#include < thread>
2729#include < type_traits>
@@ -130,7 +132,15 @@ class KernelProgramCache {
130132 struct ProgramCache {
131133 ::boost::unordered_map<ProgramCacheKeyT, ProgramBuildResultPtr> Cache;
132134 ::boost::unordered_multimap<CommonProgramKeyT, ProgramCacheKeyT> KeyMap;
135+ // Mapping between a UR program and its size.
136+ ::boost::unordered_map<ur_program_handle_t , size_t > ProgramSizeMap;
133137
138+ size_t ProgramCacheSizeInBytes = 0 ;
139+ inline size_t GetProgramCacheSizeInBytes () const noexcept {
140+ return ProgramCacheSizeInBytes;
141+ }
142+
143+ // Returns number of entries in the cache.
134144 size_t size () const noexcept { return Cache.size (); }
135145 };
136146
@@ -175,6 +185,54 @@ class KernelProgramCache {
175185 using KernelFastCacheT =
176186 ::boost::unordered_flat_map<KernelFastCacheKeyT, KernelFastCacheValT>;
177187
188+ // DS to hold data and functions related to Program cache eviction.
189+ struct EvictionListT {
190+ // Linked list of cache entries to be evicted in case of cache overflow.
191+ std::list<ProgramCacheKeyT> MProgramEvictionList;
192+
193+ // Mapping between program handle and the iterator to the eviction list.
194+ ::boost::unordered_map<ProgramCacheKeyT,
195+ std::list<ProgramCacheKeyT>::iterator>
196+ MProgramToEvictionListMap;
197+
198+ void clear () {
199+ MProgramEvictionList.clear ();
200+ MProgramToEvictionListMap.clear ();
201+ }
202+
203+ void emplaceBack (const ProgramCacheKeyT &CacheKey) {
204+ MProgramEvictionList.emplace_back (CacheKey);
205+
206+ // In std::list, the iterators are not invalidated when elements are
207+ // added/removed/moved to the list. So, we can safely store the iterators.
208+ MProgramToEvictionListMap[CacheKey] =
209+ std::prev (MProgramEvictionList.end ());
210+ traceProgram (" Program added to the end of eviction list." , CacheKey);
211+ }
212+
213+ // This function is called on the hot path, whenever a kernel/program
214+ // is accessed. So, it should be very fast.
215+ void moveToEnd (const ProgramCacheKeyT &CacheKey) {
216+ auto It = MProgramToEvictionListMap.find (CacheKey);
217+ if (It != MProgramToEvictionListMap.end ()) {
218+ MProgramEvictionList.splice (MProgramEvictionList.end (),
219+ MProgramEvictionList, It->second );
220+ traceProgram (" Program moved to the end of eviction list." , CacheKey);
221+ } else
222+ // This should never happen.
223+ assert (false && " Program not found in the eviction list." );
224+ }
225+
226+ bool empty () { return MProgramEvictionList.empty (); }
227+
228+ void popFront () {
229+ if (!MProgramEvictionList.empty ()) {
230+ MProgramToEvictionListMap.erase (MProgramEvictionList.front ());
231+ MProgramEvictionList.pop_front ();
232+ }
233+ }
234+ };
235+
178236 ~KernelProgramCache () = default ;
179237
180238 void setContextPtr (const ContextPtr &AContext) { MParentContext = AContext; }
@@ -188,12 +246,23 @@ class KernelProgramCache {
188246
189247 int ImageId = CacheKey.first .second ;
190248 std::stringstream DeviceList;
249+ std::vector<unsigned char > SerializedObjVec = CacheKey.first .first ;
250+
251+ // Convert spec constants to string. Spec constants are stored as
252+ // ASCII values, so we need to convert them to int and then to
253+ // string.
254+ std::string SerializedObjString;
255+ for (unsigned char c : SerializedObjVec)
256+ SerializedObjString += std::to_string ((int )c) + " ," ;
257+
191258 for (const auto &Device : CacheKey.second )
192259 DeviceList << " 0x" << std::setbase (16 )
193260 << reinterpret_cast <uintptr_t >(Device) << " ," ;
194261
195262 std::string Identifier = " [Key:{imageId = " + std::to_string (ImageId) +
196- " ,urDevice = " + DeviceList.str () + " }]: " ;
263+ " ,urDevice = " + DeviceList.str () +
264+ " , serializedObj = " + SerializedObjString +
265+ " }]: " ;
197266
198267 std::cerr << " [In-Memory Cache][Thread Id:" << std::this_thread::get_id ()
199268 << " ][Program Cache]" << Identifier << Msg << std::endl;
@@ -259,8 +328,7 @@ class KernelProgramCache {
259328 std::make_pair (CacheKey.first .second , CacheKey.second );
260329 ProgCache.KeyMap .emplace (CommonKey, CacheKey);
261330 traceProgram (" Program inserted." , CacheKey);
262- } else
263- traceProgram (" Program fetched." , CacheKey);
331+ }
264332 return DidInsert;
265333 }
266334
@@ -291,11 +359,186 @@ class KernelProgramCache {
291359
292360 template <typename KeyT, typename ValT>
293361 void saveKernel (KeyT &&CacheKey, ValT &&CacheVal) {
362+
363+ ur_program_handle_t Program = std::get<3 >(CacheVal);
364+ // Save kernel in fast cache only if the corresponding program is also
365+ // in the cache.
366+ {
367+ auto LockedCache = acquireCachedPrograms ();
368+ auto &ProgCache = LockedCache.get ();
369+ if (ProgCache.ProgramSizeMap .find (Program) ==
370+ ProgCache.ProgramSizeMap .end ())
371+ return ;
372+ }
373+
294374 std::unique_lock<std::mutex> Lock (MKernelFastCacheMutex);
295375 // if no insertion took place, thus some other thread has already inserted
296376 // smth in the cache
297377 traceKernel (" Kernel inserted." , std::get<3 >(CacheKey), true );
298378 MKernelFastCache.emplace (CacheKey, CacheVal);
379+
380+ if (SYCLConfig<
381+ SYCL_IN_MEM_CACHE_EVICTION_THRESHOLD>::getProgramCacheSize ()) {
382+
383+ // Save reference between the program and the fast cache key.
384+ MProgramToKernelFastCacheKeyMap[Program].emplace_back (CacheKey);
385+ }
386+ }
387+
388+ // Evict programs from cache to free up space.
389+ void evictPrograms (int DesiredCacheSize, int CurrentCacheSize) {
390+
391+ // Figure out how many programs from the beginning we need to evict.
392+ // [FIXME] Will the race on MCachedPrograms.Cache.empty() be a problem?
393+ size_t EvictSize = CurrentCacheSize - DesiredCacheSize;
394+ if (EvictSize <= 0 || MCachedPrograms.Cache .empty ())
395+ return ;
396+
397+ // Evict programs from the beginning of the cache.
398+ {
399+ std::lock_guard<std::mutex> Lock (MProgramEvictionListMutex);
400+
401+ // Traverse the eviction list and remove the LRU programs.
402+ // The LRU programs will be at the front of the list.
403+ while (EvictSize > 0 && !MEvictionList.empty ()) {
404+ ProgramCacheKeyT CacheKey = MEvictionList.MProgramEvictionList .front ();
405+ auto LockedCache = acquireCachedPrograms ();
406+ auto &ProgCache = LockedCache.get ();
407+ auto It = ProgCache.Cache .find (CacheKey);
408+ if (It != ProgCache.Cache .end ()) {
409+ // We are about to remove this program now.
410+ // (1) Remove it from KernelPerProgram cache.
411+ // (2) Remove corresponding entries from KernelFastCache.
412+ // (3) Remove it from ProgramCache KeyMap.
413+ // (4) Remove it from the ProgramCache.
414+ // (5) Remove it from ProgramSizeMap.
415+ // (6) Update the cache size.
416+
417+ // Remove entry from the KernelsPerProgram cache.
418+ ur_program_handle_t NativePrg = It->second ->Val ;
419+ {
420+ auto LockedCacheKP = acquireKernelsPerProgramCache ();
421+ LockedCacheKP.get ().erase (NativePrg);
422+ }
423+
424+ // Remove corresponding entries from KernelFastCache.
425+ auto FastCacheKeyItr =
426+ MProgramToKernelFastCacheKeyMap.find (NativePrg);
427+ if (FastCacheKeyItr != MProgramToKernelFastCacheKeyMap.end ()) {
428+ for (const auto &FastCacheKey : FastCacheKeyItr->second ) {
429+ std::unique_lock<std::mutex> Lock (MKernelFastCacheMutex);
430+ MKernelFastCache.erase (FastCacheKey);
431+ traceKernel (" Kernel evicted." , std::get<3 >(FastCacheKey), true );
432+ }
433+ MProgramToKernelFastCacheKeyMap.erase (FastCacheKeyItr);
434+ }
435+
436+ // Remove entry from ProgramCache KeyMap.
437+ CommonProgramKeyT CommonKey =
438+ std::make_pair (CacheKey.first .second , CacheKey.second );
439+ // Since KeyMap is a multi-map, we need to iterate over all entries
440+ // with this CommonKey and remove those that match the CacheKey.
441+ auto KeyMapItrRange = LockedCache.get ().KeyMap .equal_range (CommonKey);
442+ for (auto KeyMapItr = KeyMapItrRange.first ;
443+ KeyMapItr != KeyMapItrRange.second ; ++KeyMapItr) {
444+ if (KeyMapItr->second == CacheKey) {
445+ LockedCache.get ().KeyMap .erase (KeyMapItr);
446+ break ;
447+ }
448+ }
449+
450+ // Get size of the program.
451+ size_t ProgramSize = MCachedPrograms.ProgramSizeMap [It->second ->Val ];
452+ // Evict program from the cache.
453+ ProgCache.Cache .erase (It);
454+ // Remove program size from the cache size.
455+ MCachedPrograms.ProgramCacheSizeInBytes -= ProgramSize;
456+ // Remove program size from the eviction list.
457+ EvictSize -= ProgramSize;
458+
459+ MCachedPrograms.ProgramSizeMap .erase (NativePrg);
460+
461+ traceProgram (" Program evicted." , CacheKey);
462+ } else
463+ // This should never happen.
464+ assert (false && " Program not found in the cache." );
465+
466+ // Remove the program from the eviction list.
467+ MEvictionList.popFront ();
468+ }
469+ }
470+ }
471+
472+ // Register that a program has been fetched from the cache.
473+ // If it is the first time the program is fetched, add it to the eviction
474+ // list.
475+ void registerProgramFetch (const ProgramCacheKeyT &CacheKey,
476+ const ur_program_handle_t &Program,
477+ const bool IsBuilt) {
478+
479+ static size_t ProgramCacheEvictionThreshold = static_cast <size_t >(
480+ SYCLConfig<
481+ SYCL_IN_MEM_CACHE_EVICTION_THRESHOLD>::getProgramCacheSize ());
482+
483+ // No need to populate the eviction list if eviction is disabled.
484+ if (ProgramCacheEvictionThreshold == 0 )
485+ return ;
486+
487+ // If the program is not in the cache, add it to the cache.
488+ if (IsBuilt) {
489+ // This is the first time we are adding this entry. Add it to the end of
490+ // eviction list.
491+ {
492+ std::lock_guard<std::mutex> Lock (MProgramEvictionListMutex);
493+ MEvictionList.emplaceBack (CacheKey);
494+ }
495+
496+ // Store size of the program and check if we need to evict some entries.
497+ // Get Size of the program.
498+ size_t ProgramSize;
499+ auto Adapter = getAdapter ();
500+
501+ try {
502+ // Get number of devices this program was built for.
503+ unsigned int DeviceNum = 0 ;
504+ Adapter->call <UrApiKind::urProgramGetInfo>(
505+ Program, UR_PROGRAM_INFO_NUM_DEVICES, sizeof (DeviceNum), &DeviceNum,
506+ nullptr );
507+
508+ // Get binary sizes for each device.
509+ std::vector<size_t > BinarySizes (DeviceNum);
510+ Adapter->call <UrApiKind::urProgramGetInfo>(
511+ Program, UR_PROGRAM_INFO_BINARY_SIZES,
512+ sizeof (size_t ) * BinarySizes.size (), BinarySizes.data (), nullptr );
513+
514+ // Sum up binary sizes.
515+ ProgramSize =
516+ std::accumulate (BinarySizes.begin (), BinarySizes.end (), 0 );
517+ } catch (const exception &Ex) {
518+ std::cerr << " Failed to get program size: " << Ex.what () << std::endl;
519+ std::rethrow_exception (std::current_exception ());
520+ }
521+ // Store program size in the cache.
522+ size_t CurrCacheSize = 0 ;
523+ {
524+ std::lock_guard<std::mutex> Lock (MProgramCacheMutex);
525+ MCachedPrograms.ProgramSizeMap [Program] = ProgramSize;
526+ MCachedPrograms.ProgramCacheSizeInBytes += ProgramSize;
527+ CurrCacheSize = MCachedPrograms.ProgramCacheSizeInBytes ;
528+ }
529+
530+ // Evict programs if the cache size exceeds the threshold.
531+ if (CurrCacheSize > ProgramCacheEvictionThreshold)
532+ evictPrograms (ProgramCacheEvictionThreshold, CurrCacheSize);
533+ }
534+ // If the program is already in the cache, move it to the end of the list.
535+ // Since we are following LRU eviction policy, we need to move the program
536+ // to the end of the list. Items in the front of the list are the least
537+ // recently This code path is "hot" and should be very fast.
538+ else {
539+ std::lock_guard<std::mutex> Lock (MProgramEvictionListMutex);
540+ MEvictionList.moveToEnd (CacheKey);
541+ }
299542 }
300543
301544 // / Clears cache state.
@@ -308,6 +551,11 @@ class KernelProgramCache {
308551 MCachedPrograms = ProgramCache{};
309552 MKernelsPerProgramCache = KernelCacheT{};
310553 MKernelFastCache = KernelFastCacheT{};
554+ MProgramToKernelFastCacheKeyMap.clear ();
555+
556+ // Clear the eviction list while holding its mutex.
557+ std::lock_guard<std::mutex> L4 (MProgramEvictionListMutex);
558+ MEvictionList.clear ();
311559 }
312560
313561 // / Try to fetch entity (kernel or program) from cache. If there is no such
@@ -332,8 +580,10 @@ class KernelProgramCache {
332580 // /
333581 // / \return a pointer to cached build result, return value must not be
334582 // / nullptr.
335- template <errc Errc, typename GetCachedBuildFT, typename BuildFT>
336- auto getOrBuild (GetCachedBuildFT &&GetCachedBuild, BuildFT &&Build) {
583+ template <errc Errc, typename GetCachedBuildFT, typename BuildFT,
584+ typename EvicFT = void *>
585+ auto getOrBuild (GetCachedBuildFT &&GetCachedBuild, BuildFT &&Build,
586+ EvicFT &&EvictFnc = nullptr ) {
337587 using BuildState = KernelProgramCache::BuildState;
338588 constexpr size_t MaxAttempts = 2 ;
339589 for (size_t AttemptCounter = 0 ;; ++AttemptCounter) {
@@ -347,8 +597,11 @@ class KernelProgramCache {
347597 BuildState NewState = BuildResult->waitUntilTransition ();
348598
349599 // Build succeeded.
350- if (NewState == BuildState::BS_Done)
600+ if (NewState == BuildState::BS_Done) {
601+ if constexpr (!std::is_same_v<EvicFT, void *>)
602+ EvictFnc (BuildResult->Val , 0 );
351603 return BuildResult;
604+ }
352605
353606 // Build failed, or this is the last attempt.
354607 if (NewState == BuildState::BS_Failed ||
@@ -372,6 +625,9 @@ class KernelProgramCache {
372625 try {
373626 BuildResult->Val = Build ();
374627
628+ if constexpr (!std::is_same_v<EvicFT, void *>)
629+ EvictFnc (BuildResult->Val , 1 );
630+
375631 BuildResult->updateAndNotify (BuildState::BS_Done);
376632 return BuildResult;
377633 } catch (const exception &Ex) {
@@ -405,6 +661,16 @@ class KernelProgramCache {
405661
406662 std::mutex MKernelFastCacheMutex;
407663 KernelFastCacheT MKernelFastCache;
664+
665+ // Map between fast kernel cache keys and program handle.
666+ // MKernelFastCacheMutex will be used for synchronization.
667+ std::unordered_map<ur_program_handle_t , std::vector<KernelFastCacheKeyT>>
668+ MProgramToKernelFastCacheKeyMap;
669+
670+ EvictionListT MEvictionList;
671+ // Mutex that must be held when accessing the eviction list.
672+ std::mutex MProgramEvictionListMutex;
673+
408674 friend class ::MockKernelProgramCache;
409675
410676 const AdapterPtr &getAdapter ();
0 commit comments