Skip to content

Commit 758c595

Browse files
author
zourunxin.zrx
committed
search allocator
Signed-off-by: zourunxin.zrx <zourunxin.zrx@oceanbase.com>
1 parent f81606f commit 758c595

File tree

7 files changed

+40
-8
lines changed

7 files changed

+40
-8
lines changed

include/vsag/index.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,15 @@ class Index {
171171
throw std::runtime_error("Index doesn't support new filter");
172172
}
173173

174+
virtual tl::expected<DatasetPtr, Error>
175+
KnnSearch(const DatasetPtr& query,
176+
int64_t k,
177+
const std::string& parameters,
178+
const FilterPtr& filter,
179+
Allocator *allocator) const {
180+
throw std::runtime_error("Index doesn't support new filter");
181+
}
182+
174183
/**
175184
* @brief Performing single KNN search on index
176185
*

src/algorithm/hnswlib/algorithm_interface.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ class AlgorithmInterface {
4646
size_t ef,
4747
const vsag::FilterPtr is_id_allowed = nullptr,
4848
float skip_ratio = 0.9f,
49+
vsag::Allocator *allocator = nullptr,
4950
vsag::IteratorFilterContext* iter_ctx = nullptr,
5051
bool is_last_filter = false) const = 0;
5152

src/algorithm/hnswlib/hnswalg.cpp

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -439,14 +439,16 @@ HierarchicalNSW::searchBaseLayerST(InnerIdType ep_id,
439439
size_t ef,
440440
const vsag::FilterPtr is_id_allowed,
441441
const float skip_ratio,
442+
vsag::Allocator *allocator,
442443
vsag::IteratorFilterContext* iter_ctx) const {
443444
vsag::LinearCongruentialGenerator generator;
444445
VisitedListPtr vl = visited_list_pool_->getFreeVisitedList();
445446
vl_type* visited_array = vl->mass;
446447
vl_type visited_array_tag = vl->curV;
448+
vsag::Allocator *search_allocator = allocator == nullptr ? allocator_ : allocator;
447449

448-
MaxHeap top_candidates(allocator_);
449-
MaxHeap candidate_set(allocator_);
450+
MaxHeap top_candidates(search_allocator);
451+
MaxHeap candidate_set(search_allocator);
450452

451453
float valid_ratio = is_id_allowed ? is_id_allowed->ValidRatio() : 1.0F;
452454
float skip_threshold = valid_ratio == 1.0F ? 0 : (1 - ((1 - valid_ratio) * skip_ratio));
@@ -1477,16 +1479,18 @@ HierarchicalNSW::searchKnn(const void* query_data,
14771479
uint64_t ef,
14781480
const vsag::FilterPtr is_id_allowed,
14791481
const float skip_ratio,
1482+
vsag::Allocator *allocator,
14801483
vsag::IteratorFilterContext* iter_ctx,
14811484
bool is_last_filter) const {
14821485
std::shared_lock resize_lock(resize_mutex_);
14831486
std::priority_queue<std::pair<float, LabelType>> result;
14841487
if (cur_element_count_ == 0)
14851488
return result;
14861489

1490+
vsag::Allocator *search_alloctor = allocator == nullptr ? allocator_ : allocator;
14871491
std::shared_ptr<float[]> normalize_query;
14881492
normalizeVector(query_data, normalize_query);
1489-
MaxHeap top_candidates(allocator_);
1493+
MaxHeap top_candidates(search_alloctor);
14901494
if (iter_ctx != nullptr && !iter_ctx->IsFirstUsed()) {
14911495
if (iter_ctx->Empty())
14921496
return result;
@@ -1504,6 +1508,7 @@ HierarchicalNSW::searchKnn(const void* query_data,
15041508
std::max(ef, k),
15051509
is_id_allowed,
15061510
skip_ratio,
1511+
allocator,
15071512
iter_ctx);
15081513
} else {
15091514
int64_t currObj;
@@ -1545,10 +1550,10 @@ HierarchicalNSW::searchKnn(const void* query_data,
15451550

15461551
if (num_deleted_ == 0) {
15471552
top_candidates = searchBaseLayerST<false, true>(
1548-
currObj, query_data, std::max(ef, k), is_id_allowed, skip_ratio, iter_ctx);
1553+
currObj, query_data, std::max(ef, k), is_id_allowed, skip_ratio, allocator, iter_ctx);
15491554
} else {
15501555
top_candidates = searchBaseLayerST<true, true>(
1551-
currObj, query_data, std::max(ef, k), is_id_allowed, skip_ratio, iter_ctx);
1556+
currObj, query_data, std::max(ef, k), is_id_allowed, skip_ratio, allocator, iter_ctx);
15521557
}
15531558
}
15541559

@@ -1669,6 +1674,7 @@ HierarchicalNSW::searchBaseLayerST<false, false>(
16691674
size_t ef,
16701675
const vsag::FilterPtr is_id_allowed,
16711676
const float skip_ratio,
1677+
vsag::Allocator *allocator,
16721678
vsag::IteratorFilterContext* iter_ctx = nullptr) const;
16731679

16741680
template MaxHeap

src/algorithm/hnswlib/hnswalg.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,7 @@ class HierarchicalNSW : public AlgorithmInterface<float> {
255255
size_t ef,
256256
const vsag::FilterPtr is_id_allowed = nullptr,
257257
const float skip_ratio = 0.9f,
258+
vsag::Allocator *allocator = nullptr,
258259
vsag::IteratorFilterContext* iter_ctx = nullptr) const;
259260

260261
template <bool has_deletions, bool collect_metrics = false>
@@ -411,6 +412,7 @@ class HierarchicalNSW : public AlgorithmInterface<float> {
411412
uint64_t ef,
412413
const vsag::FilterPtr is_id_allowed = nullptr,
413414
const float skip_ratio = 0.9f,
415+
vsag::Allocator *allocator = nullptr,
414416
vsag::IteratorFilterContext* iter_ctx = nullptr,
415417
bool is_last_filter = false) const override;
416418

src/algorithm/hnswlib/hnswalg_static.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1467,6 +1467,7 @@ class StaticHierarchicalNSW : public AlgorithmInterface<float> {
14671467
uint64_t ef,
14681468
const vsag::FilterPtr is_id_allowed = nullptr,
14691469
const float skip_ratio = 0.9f,
1470+
vsag::Allocator *allocator = nullptr,
14701471
vsag::IteratorFilterContext* iter_ctx = nullptr,
14711472
bool is_last_filter = false) const override {
14721473
std::priority_queue<std::pair<float, LabelType>> result;

src/index/hnsw.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,7 @@ HNSW::knn_search(const DatasetPtr& query,
204204
int64_t k,
205205
const std::string& parameters,
206206
const FilterPtr& filter_ptr,
207+
vsag::Allocator *allocator,
207208
vsag::IteratorContext** iter_ctx,
208209
bool is_last_filter) const {
209210
#ifndef ENABLE_TESTS
@@ -216,6 +217,7 @@ HNSW::knn_search(const DatasetPtr& query,
216217
ret->Dim(0)->NumElements(1);
217218
return ret;
218219
}
220+
vsag::Allocator *search_alloctor = allocator == nullptr ? allocator_.get() : allocator;
219221

220222
// check query vector
221223
CHECK_ARGUMENT(query->GetNumElements() == 1, "query dataset should contain 1 vector only");
@@ -238,7 +240,7 @@ HNSW::knn_search(const DatasetPtr& query,
238240

239241
if (iter_ctx != nullptr && *iter_ctx == nullptr) {
240242
auto* filter_context = new IteratorFilterContext();
241-
filter_context->init(alg_hnsw_->getMaxElements(), params.ef_search, allocator_.get());
243+
filter_context->init(alg_hnsw_->getMaxElements(), params.ef_search, search_alloctor);
242244
*iter_ctx = filter_context;
243245
}
244246
IteratorFilterContext* iter_filter_ctx = nullptr;
@@ -260,6 +262,7 @@ HNSW::knn_search(const DatasetPtr& query,
260262
std::max(params.ef_search, k),
261263
filter_ptr,
262264
params.skip_ratio,
265+
allocator,
263266
iter_filter_ctx,
264267
is_last_filter);
265268
} catch (const std::runtime_error& e) {
@@ -299,7 +302,7 @@ HNSW::knn_search(const DatasetPtr& query,
299302
results.pop();
300303
}
301304
auto [dataset_results, dists, ids] =
302-
CreateFastDataset(static_cast<int64_t>(results.size()), allocator_.get());
305+
CreateFastDataset(static_cast<int64_t>(results.size()), search_alloctor);
303306

304307
for (auto j = static_cast<int64_t>(results.size() - 1); j >= 0; --j) {
305308
dists[j] = results.top().first;

src/index/hnsw.h

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,15 @@ class HNSW : public Index {
111111
SAFE_CALL(return this->knn_search(query, k, parameters, filter));
112112
}
113113

114+
tl::expected<DatasetPtr, Error>
115+
KnnSearch(const DatasetPtr& query,
116+
int64_t k,
117+
const std::string& parameters,
118+
const FilterPtr& filter,
119+
Allocator *allocator) const override {
120+
SAFE_CALL(return this->knn_search(query, k, parameters, filter, allocator));
121+
}
122+
114123
tl::expected<DatasetPtr, Error>
115124
KnnSearch(const DatasetPtr& query,
116125
int64_t k,
@@ -119,7 +128,7 @@ class HNSW : public Index {
119128
vsag::IteratorContext*& filter_ctx,
120129
bool is_last_search) const override {
121130
SAFE_CALL(
122-
return this->knn_search(query, k, parameters, filter, &filter_ctx, is_last_search));
131+
return this->knn_search(query, k, parameters, filter, nullptr, &filter_ctx, is_last_search));
123132
}
124133

125134
tl::expected<DatasetPtr, Error>
@@ -282,6 +291,7 @@ class HNSW : public Index {
282291
int64_t k,
283292
const std::string& parameters,
284293
const FilterPtr& filter_ptr,
294+
vsag::Allocator *allocator = nullptr,
285295
vsag::IteratorContext** iter_ctx = nullptr,
286296
bool is_last_filter = false) const;
287297

0 commit comments

Comments
 (0)