@@ -294,6 +294,11 @@ class KernelProgramCache {
294294 MProgramEvictionList.pop_front ();
295295 }
296296 }
297+
298+ void erase (const ProgramCacheKeyT &CacheKey) {
299+ MProgramEvictionList.remove (CacheKey);
300+ MProgramToEvictionListMap.erase (CacheKey);
301+ }
297302 };
298303
299304 ~KernelProgramCache () = default ;
@@ -427,31 +432,100 @@ class KernelProgramCache {
427432
428433 template <typename KeyT, typename ValT>
429434 void saveKernel (KeyT &&CacheKey, ValT &&CacheVal) {
430-
435+ ur_program_handle_t Program = std::get< 3 >(CacheVal);
431436 if (SYCLConfig<SYCL_IN_MEM_CACHE_EVICTION_THRESHOLD>::
432437 isProgramCacheEvictionEnabled ()) {
433438
434- ur_program_handle_t Program = std::get<3 >(CacheVal);
435439 // Save kernel in fast cache only if the corresponding program is also
436440 // in the cache.
437441 auto LockedCache = acquireCachedPrograms ();
438442 auto &ProgCache = LockedCache.get ();
439443 if (ProgCache.ProgramSizeMap .find (Program) ==
440444 ProgCache.ProgramSizeMap .end ())
441445 return ;
442-
443- // Save reference between the program and the fast cache key.
444- std::unique_lock<std::mutex> Lock (MKernelFastCacheMutex);
445- MProgramToKernelFastCacheKeyMap[Program].emplace_back (CacheKey);
446446 }
447-
447+ // Save reference between the program and the fast cache key.
448448 std::unique_lock<std::mutex> Lock (MKernelFastCacheMutex);
449+ MProgramToKernelFastCacheKeyMap[Program].emplace_back (CacheKey);
450+
449451 // if no insertion took place, thus some other thread has already inserted
450452 // smth in the cache
451453 traceKernel (" Kernel inserted." , CacheKey.second , true );
452454 MKernelFastCache.emplace (CacheKey, CacheVal);
453455 }
454456
457+ // Expects locked program cache
458+ size_t removeProgramByKey (const ProgramCacheKeyT &CacheKey,
459+ ProgramCache &ProgCache) {
460+ auto It = ProgCache.Cache .find (CacheKey);
461+
462+ if (It != ProgCache.Cache .end ()) {
463+ // We are about to remove this program now.
464+ // (1) Remove it from KernelPerProgram cache.
465+ // (2) Remove corresponding entries from KernelFastCache.
466+ // (3) Remove it from ProgramCache KeyMap.
467+ // (4) Remove it from the ProgramCache.
468+ // (5) Remove it from ProgramSizeMap.
469+ // (6) Update the cache size.
470+
471+ // Remove entry from the KernelsPerProgram cache.
472+ ur_program_handle_t NativePrg = It->second ->Val ;
473+ {
474+ auto LockedCacheKP = acquireKernelsPerProgramCache ();
475+ // List kernels that are to be removed from the cache, if tracing is
476+ // enabled.
477+ if (SYCLConfig<SYCL_CACHE_TRACE>::isTraceInMemCache ()) {
478+ for (const auto &Kernel : LockedCacheKP.get ()[NativePrg])
479+ traceKernel (" Kernel evicted." , Kernel.first );
480+ }
481+ LockedCacheKP.get ().erase (NativePrg);
482+ }
483+
484+ {
485+ // Remove corresponding entries from KernelFastCache.
486+ std::unique_lock<std::mutex> Lock (MKernelFastCacheMutex);
487+ if (auto FastCacheKeyItr =
488+ MProgramToKernelFastCacheKeyMap.find (NativePrg);
489+ FastCacheKeyItr != MProgramToKernelFastCacheKeyMap.end ()) {
490+ for (const auto &FastCacheKey : FastCacheKeyItr->second ) {
491+ MKernelFastCache.erase (FastCacheKey);
492+ traceKernel (" Kernel evicted." , FastCacheKey.second , true );
493+ }
494+ MProgramToKernelFastCacheKeyMap.erase (FastCacheKeyItr);
495+ }
496+ }
497+
498+ // Remove entry from ProgramCache KeyMap.
499+ CommonProgramKeyT CommonKey =
500+ std::make_pair (CacheKey.first .second , CacheKey.second );
501+ // Since KeyMap is a multi-map, we need to iterate over all entries
502+ // with this CommonKey and remove those that match the CacheKey.
503+ auto KeyMapItrRange = ProgCache.KeyMap .equal_range (CommonKey);
504+ for (auto KeyMapItr = KeyMapItrRange.first ;
505+ KeyMapItr != KeyMapItrRange.second ; ++KeyMapItr) {
506+ if (KeyMapItr->second == CacheKey) {
507+ ProgCache.KeyMap .erase (KeyMapItr);
508+ break ;
509+ }
510+ }
511+
512+ // Get size of the program.
513+ size_t ProgramSize = MCachedPrograms.ProgramSizeMap [It->second ->Val ];
514+ // Evict program from the cache.
515+ ProgCache.Cache .erase (It);
516+ // Remove program size from the cache size.
517+ MCachedPrograms.ProgramCacheSizeInBytes -= ProgramSize;
518+ MCachedPrograms.ProgramSizeMap .erase (NativePrg);
519+
520+ traceProgram (" Program evicted." , CacheKey);
521+ } else
522+ // This should never happen.
523+ throw sycl::exception (sycl::make_error_code (sycl::errc::runtime),
524+ " Program not found in the cache." );
525+
526+ return MCachedPrograms.ProgramCacheSizeInBytes ;
527+ }
528+
455529 // Evict programs from cache to free up space.
456530 void evictPrograms (size_t DesiredCacheSize, size_t CurrentCacheSize) {
457531
@@ -471,73 +545,7 @@ class KernelProgramCache {
471545 ProgramCacheKeyT CacheKey = ProgramEvictionList.front ();
472546 auto LockedCache = acquireCachedPrograms ();
473547 auto &ProgCache = LockedCache.get ();
474- auto It = ProgCache.Cache .find (CacheKey);
475-
476- if (It != ProgCache.Cache .end ()) {
477- // We are about to remove this program now.
478- // (1) Remove it from KernelPerProgram cache.
479- // (2) Remove corresponding entries from KernelFastCache.
480- // (3) Remove it from ProgramCache KeyMap.
481- // (4) Remove it from the ProgramCache.
482- // (5) Remove it from ProgramSizeMap.
483- // (6) Update the cache size.
484-
485- // Remove entry from the KernelsPerProgram cache.
486- ur_program_handle_t NativePrg = It->second ->Val ;
487- {
488- auto LockedCacheKP = acquireKernelsPerProgramCache ();
489- // List kernels that are to be removed from the cache, if tracing is
490- // enabled.
491- if (SYCLConfig<SYCL_CACHE_TRACE>::isTraceInMemCache ()) {
492- for (const auto &Kernel : LockedCacheKP.get ()[NativePrg])
493- traceKernel (" Kernel evicted." , Kernel.first );
494- }
495- LockedCacheKP.get ().erase (NativePrg);
496- }
497-
498- {
499- // Remove corresponding entries from KernelFastCache.
500- std::unique_lock<std::mutex> Lock (MKernelFastCacheMutex);
501- if (auto FastCacheKeyItr =
502- MProgramToKernelFastCacheKeyMap.find (NativePrg);
503- FastCacheKeyItr != MProgramToKernelFastCacheKeyMap.end ()) {
504- for (const auto &FastCacheKey : FastCacheKeyItr->second ) {
505- MKernelFastCache.erase (FastCacheKey);
506- traceKernel (" Kernel evicted." , FastCacheKey.second , true );
507- }
508- MProgramToKernelFastCacheKeyMap.erase (FastCacheKeyItr);
509- }
510- }
511-
512- // Remove entry from ProgramCache KeyMap.
513- CommonProgramKeyT CommonKey =
514- std::make_pair (CacheKey.first .second , CacheKey.second );
515- // Since KeyMap is a multi-map, we need to iterate over all entries
516- // with this CommonKey and remove those that match the CacheKey.
517- auto KeyMapItrRange = LockedCache.get ().KeyMap .equal_range (CommonKey);
518- for (auto KeyMapItr = KeyMapItrRange.first ;
519- KeyMapItr != KeyMapItrRange.second ; ++KeyMapItr) {
520- if (KeyMapItr->second == CacheKey) {
521- LockedCache.get ().KeyMap .erase (KeyMapItr);
522- break ;
523- }
524- }
525-
526- // Get size of the program.
527- size_t ProgramSize = MCachedPrograms.ProgramSizeMap [It->second ->Val ];
528- // Evict program from the cache.
529- ProgCache.Cache .erase (It);
530- // Remove program size from the cache size.
531- MCachedPrograms.ProgramCacheSizeInBytes -= ProgramSize;
532- MCachedPrograms.ProgramSizeMap .erase (NativePrg);
533-
534- traceProgram (" Program evicted." , CacheKey);
535- } else
536- // This should never happen.
537- throw sycl::exception (sycl::make_error_code (sycl::errc::runtime),
538- " Program not found in the cache." );
539-
540- CurrCacheSize = MCachedPrograms.ProgramCacheSizeInBytes ;
548+ CurrCacheSize = removeProgramByKey (CacheKey, ProgCache);
541549 // Remove the program from the eviction list.
542550 MEvictionList.popFront ();
543551 }
@@ -724,6 +732,24 @@ class KernelProgramCache {
724732 }
725733 }
726734
735+ void removeAllRelatedEntries (uint32_t ImageId) {
736+ auto LockedCache = acquireCachedPrograms ();
737+ auto &ProgCache = LockedCache.get ();
738+
739+ auto It = std::find_if (
740+ ProgCache.KeyMap .begin (), ProgCache.KeyMap .end (),
741+ [&ImageId](const auto &Entry) { return ImageId == Entry.first .first ; });
742+ if (It == ProgCache.KeyMap .end ())
743+ return ;
744+
745+ auto Key = It->second ;
746+ removeProgramByKey (Key, ProgCache);
747+ {
748+ auto LockedEvictionList = acquireEvictionList ();
749+ LockedEvictionList.get ().erase (Key);
750+ }
751+ }
752+
727753private:
728754 std::mutex MProgramCacheMutex;
729755 std::mutex MKernelsPerProgramCacheMutex;
0 commit comments