Skip to content

Commit 6f425cc

Browse files
author
Qi Chen
committed
fix merge/get deadlock issue
1 parent 734a845 commit 6f425cc

File tree

1 file changed

+21
-33
lines changed

1 file changed

+21
-33
lines changed

AnnService/inc/Core/SPANN/ExtraFileController.h

Lines changed: 21 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include "inc/Core/Common/Dataset.h"
77
#include "inc/Core/VectorIndex.h"
88
#include "inc/Helper/ThreadPool.h"
9+
#include "inc/Helper/ConcurrentSet.h"
910
#include "inc/Helper/AsyncFileReader.h"
1011
#include "inc/Core/SPANN/Options.h"
1112
#include <cstdlib>
@@ -196,10 +197,10 @@ namespace SPTAG::SPANN {
196197
int limit;
197198
std::uint64_t size;
198199
std::list<SizeType> keys; // Page Address
199-
std::unordered_map<SizeType, std::pair<std::string, std::list<SizeType>::iterator>> cache; // Page Address -> Page Address in Cache
200-
std::mutex mu;
200+
Helper::Concurrent::ConcurrentMap<SizeType, std::pair<std::string, std::list<SizeType>::iterator>> cache; // Page Address -> Page Address in Cache
201+
std::shared_timed_mutex mu;
201202
// Total number of get() calls. Atomic because get() increments it while holding only a shared (reader) lock, so several threads may bump it concurrently.
std::atomic<int64_t> queries;
202-
int64_t hits;
203+
std::atomic<int64_t> hits;
203204
FileIO* fileIO;
204205
std::vector<Helper::AsyncReadRequest> reqs;
205206

@@ -230,89 +231,77 @@ namespace SPTAG::SPANN {
230231
}
231232

232233
bool get(SizeType key, void* value) {
233-
mu.lock();
234+
std::shared_lock<std::shared_timed_mutex> lock(mu);
234235
queries++;
235236
auto it = cache.find(key);
236237
if (it == cache.end()) {
237-
mu.unlock();
238238
return false; // If the key does not exist, return false
239239
}
240240
// Copy the cached value into the caller's buffer (get() holds only a shared lock, so the LRU order is not updated here)
241241
memcpy(value, it->second.first.data(), it->second.first.size());
242-
keys.splice(keys.begin(), keys, it->second.second);
243-
it->second.second = keys.begin();
244242
hits++;
245-
mu.unlock();
246243
return true;
247244
}
248245

249246
bool put(SizeType key, void* value, int put_size) {
250-
mu.lock();
247+
std::unique_lock<std::shared_timed_mutex> lock(mu);
251248
auto it = cache.find(key);
252249
if (it != cache.end()) {
253250
if (put_size > limit) {
254251
evict(key, it->second.first.data(), it->second.first.size(), it);
255-
mu.unlock();
256252
return false;
257253
}
258254
keys.splice(keys.begin(), keys, it->second.second);
259255
it->second.second = keys.begin();
256+
260257
auto delta_size = put_size - it->second.first.size();
261258
while ((capacity - size) < delta_size && (keys.size() > 1)) {
262259
auto last = keys.back();
263260
auto lastit = cache.find(last);
264261
if (!evict(last, lastit->second.first.data(), lastit->second.first.size(), lastit)) {
265-
mu.unlock();
266262
return false;
267263
}
268264
}
269265
it->second.first.resize(put_size);
270266
memcpy(it->second.first.data(), value, put_size);
271267
size += delta_size;
272-
mu.unlock();
273268
return true;
274269
}
275270
if (put_size > limit) {
276-
mu.unlock();
277271
return false;
278272
}
279273
while (put_size > (capacity - size) && (!keys.empty())) {
280274
auto last = keys.back();
281275
auto lastit = cache.find(last);
282276
if (!evict(last, lastit->second.first.data(), lastit->second.first.size(), lastit)) {
283-
mu.unlock();
284277
return false;
285278
}
286279
}
287280
auto keys_it = keys.insert(keys.begin(), key);
288281
cache.insert({key, {std::string((char*)value, put_size), keys_it}});
289282
size += put_size;
290-
mu.unlock();
291283
return true;
292284
}
293285

294286
bool del(SizeType key) {
295-
mu.lock();
287+
std::unique_lock<std::shared_timed_mutex> lock(mu);
296288
auto it = cache.find(key);
297289
if (it == cache.end()) {
298-
mu.unlock();
299290
return false; // If the key does not exist, return false
300291
}
301292
evict(key, nullptr, 0, it);
302-
mu.unlock();
303293
return true;
304294
}
305295

306296
bool merge(SizeType key, void* value, AddressType merge_size) {
307-
mu.lock();
297+
std::unique_lock<std::shared_timed_mutex> lock(mu);
308298
// SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "LRUCache: merge size: %lld\n", merge_size);
309299
auto it = cache.find(key);
310300
if (it == cache.end()) {
311301
// SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "LRUCache: merge key not found\n");
312302
std::string valstr;
313-
if (fileIO->Get(key, &valstr, MaxTimeout, &reqs) != ErrorCode::Success) {
303+
if (fileIO->Get(key, &valstr, MaxTimeout, &reqs, false) != ErrorCode::Success) {
314304
SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "LRUCache: merge key not found in file\n");
315-
mu.unlock();
316305
return false; // If the key does not exist, return false
317306
}
318307
cache.insert({key, {valstr, keys.insert(keys.begin(), key)}});
@@ -323,7 +312,6 @@ namespace SPTAG::SPANN {
323312
if (merge_size + it->second.first.size() > limit) {
324313
evict(key, it->second.first.data(), it->second.first.size(), it);
325314
// SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "LRUCache: merge size exceeded\n");
326-
mu.unlock();
327315
return false;
328316
}
329317
keys.splice(keys.begin(), keys, it->second.second);
@@ -332,34 +320,30 @@ namespace SPTAG::SPANN {
332320
auto last = keys.back();
333321
auto lastit = cache.find(last);
334322
if (!evict(last, lastit->second.first.data(), lastit->second.first.size(), lastit)) {
335-
mu.unlock();
336323
return false;
337324
}
338325
}
339326
it->second.first.append((char*)value, merge_size);
340327
size += merge_size;
341328
// SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "LRUCache: merge success\n");
342-
mu.unlock();
343329
return true;
344330
}
345331

346332
std::pair<int64_t, int64_t> get_stat() {
347-
return {queries, hits};
333+
return {queries, hits.load()};
348334
}
349335

350336
bool flush() {
351-
mu.lock();
337+
std::unique_lock<std::shared_timed_mutex> lock(mu);
352338
for (auto it = cache.begin(); it != cache.end(); it++) {
353339
if (fileIO->Put(it->first, it->second.first, MaxTimeout, &reqs, false) != ErrorCode::Success) {
354340
SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "LRUCache: evict key:%d value size:%d to file failed\n", it->first, (int)(it->second.first.size()));
355-
mu.unlock();
356341
return false;
357342
}
358343
}
359344
cache.clear();
360345
keys.clear();
361346
size = 0;
362-
mu.unlock();
363347
return true;
364348
}
365349
};
@@ -551,7 +535,7 @@ namespace SPTAG::SPANN {
551535
return *(m_pBlockMapping[key]);
552536
}
553537

554-
ErrorCode Get(const SizeType key, std::string* value, const std::chrono::microseconds &timeout, std::vector<Helper::AsyncReadRequest>* reqs) override {
538+
ErrorCode Get(const SizeType key, std::string* value, const std::chrono::microseconds &timeout, std::vector<Helper::AsyncReadRequest>* reqs, bool useCache) {
555539
auto get_begin_time = std::chrono::high_resolution_clock::now();
556540
if (m_fileIoUseLock) {
557541
m_rwMutex[hash(key)].lock_shared();
@@ -578,7 +562,7 @@ namespace SPTAG::SPANN {
578562
auto size = addr[0];
579563
if (size < 0) return ErrorCode::Posting_SizeError;
580564

581-
if (m_pShardedLRUCache) {
565+
if (useCache && m_pShardedLRUCache) {
582566
value->resize(size);
583567
if (m_pShardedLRUCache->get(key, value->data())) {
584568
return ErrorCode::Success;
@@ -601,8 +585,12 @@ namespace SPTAG::SPANN {
601585
return result ? ErrorCode::Success : ErrorCode::Fail;
602586
}
603587

588+
ErrorCode Get(const SizeType key, std::string* value, const std::chrono::microseconds &timeout, std::vector<Helper::AsyncReadRequest>* reqs) override {
589+
return Get(key, value, timeout, reqs, true);
590+
}
591+
604592
ErrorCode Get(const std::string& key, std::string* value, const std::chrono::microseconds& timeout, std::vector<Helper::AsyncReadRequest>* reqs) override {
605-
return Get(std::stoi(key), value, timeout, reqs);
593+
return Get(std::stoi(key), value, timeout, reqs, true);
606594
}
607595

608596
ErrorCode MultiGet(const std::vector<SizeType>& keys, std::vector<Helper::PageBuffer<std::uint8_t>>& values,
@@ -753,7 +741,7 @@ namespace SPTAG::SPANN {
753741
}
754742
*/
755743

756-
ErrorCode Put(const SizeType key, const std::string& value, const std::chrono::microseconds& timeout, std::vector<Helper::AsyncReadRequest>* reqs, bool cache) {
744+
ErrorCode Put(const SizeType key, const std::string& value, const std::chrono::microseconds& timeout, std::vector<Helper::AsyncReadRequest>* reqs, bool useCache) {
757745
int blocks = (int)(((value.size() + PageSize - 1) >> PageSizeEx));
758746
if (blocks >= m_blockLimit) {
759747
SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Fail to put key:%d value:%lld since value too long!\n", key, value.size());
@@ -782,7 +770,7 @@ namespace SPTAG::SPANN {
782770
m_updateMutex.unlock();
783771
}
784772

785-
if (cache && m_pShardedLRUCache) {
773+
if (useCache && m_pShardedLRUCache) {
786774
if (m_pShardedLRUCache->put(key, (void*)(value.data()), (SPTAG::SizeType)(value.size()))) {
787775
return ErrorCode::Success;
788776
}

0 commit comments

Comments
 (0)