Skip to content

Commit dee5fe4

Browse files
Rework LRU cache
Fix issue with std::unordered_map iterator invalidation Use std::list instead of std::deque to keep eviction list
1 parent b870738 commit dee5fe4

File tree

1 file changed

+77
-35
lines changed

1 file changed

+77
-35
lines changed

Common/interface/LRUCache.hpp

Lines changed: 77 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
#pragma once
1919

2020
#include <unordered_map>
21-
#include <deque>
21+
#include <list>
2222
#include <mutex>
2323
#include <memory>
2424
#include <algorithm>
@@ -54,6 +54,9 @@ namespace Diligent
5454
///
5555
/// If the data is not found, it is atomically initialized by the provided initializer function.
5656
/// If the data is found, the initializer function is not called.
57+
///
58+
/// \note The initialization function must not call Get() on the same cache instance
59+
/// to avoid potential deadlocks.
5760
template <typename KeyType, typename DataType, typename KeyHasher = std::hash<KeyType>>
5861
class LRUCache
5962
{
@@ -98,7 +101,7 @@ class LRUCache
98101
bool IsNewObject = false;
99102
// InitData may throw, which will leave the wrapper in the cache in the 'InitFailure' state.
100103
// It will be removed from the cache later when the LRU queue is processed.
101-
auto Data = pDataWrpr->GetData(std::forward<InitDataType>(InitData), IsNewObject);
104+
DataType Data = pDataWrpr->GetData(std::forward<InitDataType>(InitData), IsNewObject);
102105

103106
// Process the release queue
104107
std::vector<std::shared_ptr<DataWrapper>> DeleteList;
@@ -115,7 +118,7 @@ class LRUCache
115118
if (it != m_Cache.end())
116119
{
117120
// Check that the object wrapper is the same.
118-
if (it->second == pDataWrpr)
121+
if (it->second.Wrpr == pDataWrpr)
119122
{
120123
// The wrapper is in the cache - label it as accounted and update the cache size.
121124

@@ -145,12 +148,9 @@ class LRUCache
145148
}
146149
}
147150

148-
for (int idx = static_cast<int>(m_LRUQueue.size()) - 1; idx >= 0; --idx)
151+
for (auto lru_it = m_LRU.begin(); lru_it != m_LRU.end() && m_CurrSize > m_MaxSize;)
149152
{
150-
if (m_CurrSize <= m_MaxSize)
151-
break;
152-
153-
VERIFY_EXPR(!m_LRUQueue.empty());
153+
const KeyType& EvictKey = *lru_it;
154154

155155
// State stransition table:
156156
// Protected by m_Mtx Accounted Size
@@ -160,12 +160,22 @@ class LRUCache
160160
// InitializedUnaccounted -> InitializedAccounted Yes !0 <U2A>
161161
// InitializedAccounted Final State
162162
//
163-
const auto& cache_it = m_LRUQueue[idx];
164-
const auto State = cache_it->second->GetState(); /* <ReadState> */
163+
const auto cache_it = m_Cache.find(EvictKey);
164+
if (cache_it == m_Cache.end())
165+
{
166+
UNEXPECTED("Unavailable key in LRU list. This should never happen.");
167+
lru_it = m_LRU.erase(lru_it);
168+
continue;
169+
}
170+
VERIFY_EXPR(cache_it->second.LRUIt == lru_it);
171+
172+
std::shared_ptr<DataWrapper>& pWrpr = cache_it->second.Wrpr;
173+
const DataWrapper::DataState State = pWrpr->GetState(); /* <ReadState> */
165174
if (State == DataWrapper::DataState::Default)
166175
{
167176
// The object is being initialized in another thread in DataWrapper::Get().
168177
// Possible actual states here are Default, InitializedUnaccounted or InitFailure.
178+
++lru_it;
169179
continue;
170180
}
171181
if (State == DataWrapper::DataState::InitializedUnaccounted)
@@ -174,6 +184,7 @@ class LRUCache
174184
// in the cache yet as this thread acquired the mutex first.
175185
// The only possible actual state here is InitializedUnaccounted as transition to
176186
// InitializedAccounted in <SA> requires mutex.
187+
++lru_it;
177188
continue;
178189
}
179190

@@ -191,19 +202,20 @@ class LRUCache
191202

192203
// NB: if the state was not InitializedAccounted when we read it in <ReadState>, it can't be
193204
// InitializedAccounted now since the transition <U2A> is protected by mutex in <SA>.
194-
VERIFY_EXPR((State == DataWrapper::DataState::InitializedAccounted && cache_it->second->GetState() == DataWrapper::DataState::InitializedAccounted) ||
195-
(State != DataWrapper::DataState::InitializedAccounted && cache_it->second->GetState() != DataWrapper::DataState::InitializedAccounted));
205+
VERIFY_EXPR((State == DataWrapper::DataState::InitializedAccounted && pWrpr->GetState() == DataWrapper::DataState::InitializedAccounted) ||
206+
(State != DataWrapper::DataState::InitializedAccounted && pWrpr->GetState() != DataWrapper::DataState::InitializedAccounted));
196207

197208
// Note that transition to InitializedAccounted state is protected by the mutex in <SA>, so
198209
// we can't remove a wrapper before it was accounted for.
199-
const size_t AccountedSize = cache_it->second->GetAccountedSize();
200-
DeleteList.emplace_back(std::move(cache_it->second));
210+
const size_t AccountedSize = pWrpr->GetAccountedSize();
211+
DeleteList.emplace_back(std::move(pWrpr));
201212
m_Cache.erase(cache_it); /* <Erase> */
202-
m_LRUQueue.erase(m_LRUQueue.begin() + idx);
213+
lru_it = m_LRU.erase(lru_it);
203214
VERIFY_EXPR(m_CurrSize >= AccountedSize);
204215
m_CurrSize -= AccountedSize;
205216
}
206-
VERIFY_EXPR(m_Cache.size() == m_LRUQueue.size());
217+
218+
VERIFY_EXPR(m_Cache.size() == m_LRU.size());
207219
}
208220

209221
// Delete objects after releasing the cache mutex
@@ -228,15 +240,19 @@ class LRUCache
228240
{
229241
#ifdef DILIGENT_DEBUG
230242
size_t DbgSize = 0;
231-
VERIFY_EXPR(m_Cache.size() == m_LRUQueue.size());
232-
while (!m_LRUQueue.empty())
243+
VERIFY_EXPR(m_Cache.size() == m_LRU.size());
244+
for (const KeyType& Key : m_LRU)
233245
{
234-
auto last_it = m_LRUQueue.back();
235-
m_LRUQueue.pop_back();
236-
DbgSize += last_it->second->GetAccountedSize();
237-
m_Cache.erase(last_it);
246+
auto it = m_Cache.find(Key);
247+
if (it != m_Cache.end())
248+
{
249+
DbgSize += it->second.Wrpr->GetAccountedSize();
250+
}
251+
else
252+
{
253+
UNEXPECTED("Unexpected key in LRU list");
254+
}
238255
}
239-
VERIFY_EXPR(m_Cache.empty());
240256
VERIFY_EXPR(DbgSize == m_CurrSize);
241257
#endif
242258
}
@@ -256,6 +272,16 @@ class LRUCache
256272
template <typename InitDataType>
257273
const DataType& GetData(InitDataType&& InitData, bool& IsNewObject) noexcept(false)
258274
{
275+
// Fast path
276+
{
277+
const DataState CurrentState = m_State.load();
278+
if (CurrentState == DataState::InitializedAccounted ||
279+
CurrentState == DataState::InitializedUnaccounted)
280+
{
281+
return m_Data;
282+
}
283+
}
284+
259285
std::lock_guard<std::mutex> Lock{m_InitDataMtx};
260286
if (m_DataSize == 0)
261287
{
@@ -320,28 +346,44 @@ class LRUCache
320346
auto it = m_Cache.find(Key);
321347
if (it == m_Cache.end())
322348
{
323-
it = m_Cache.emplace(Key, std::make_shared<DataWrapper>()).first;
349+
// Do the potentially-throwing allocations before modifying any cache state
350+
std::shared_ptr<DataWrapper> pWrpr = std::make_shared<DataWrapper>(); // May throw
351+
352+
m_LRU.push_back(Key);
353+
try
354+
{
355+
it = m_Cache.emplace(Key, Entry{std::move(pWrpr), std::prev(m_LRU.end())}).first;
356+
}
357+
catch (...)
358+
{
359+
m_LRU.pop_back();
360+
throw;
361+
}
324362
}
325363
else
326364
{
327-
// Pop the wrapper iterator from the queue
328-
auto queue_it = std::find(m_LRUQueue.begin(), m_LRUQueue.end(), it);
329-
VERIFY_EXPR(queue_it != m_LRUQueue.end());
330-
m_LRUQueue.erase(queue_it);
365+
// Move to MRU (back of the list)
366+
m_LRU.splice(m_LRU.end(), m_LRU, it->second.LRUIt);
331367
}
332368

333-
// Move iterator to the front of the queue
334-
m_LRUQueue.push_front(it);
335-
VERIFY_EXPR(m_Cache.size() == m_LRUQueue.size());
369+
VERIFY_EXPR(m_Cache.size() == m_LRU.size());
336370

337-
return it->second;
371+
return it->second.Wrpr;
338372
}
339373

340374

341-
using CacheType = std::unordered_map<KeyType, std::shared_ptr<DataWrapper>, KeyHasher>;
342-
CacheType m_Cache;
375+
using LRUList = std::list<KeyType>; // LRU at front, MRU at back
343376

344-
std::deque<typename CacheType::iterator> m_LRUQueue;
377+
struct Entry
378+
{
379+
std::shared_ptr<DataWrapper> Wrpr;
380+
typename LRUList::iterator LRUIt; // Stable iterator into list
381+
};
382+
383+
using CacheType = std::unordered_map<KeyType, Entry, KeyHasher>;
384+
385+
CacheType m_Cache;
386+
LRUList m_LRU;
345387

346388
std::mutex m_Mtx;
347389

0 commit comments

Comments
 (0)