Skip to content

Commit 28b38b9

Browse files
author
zourunxin.zrx
committed
hgraph search allocator
1 parent 758c595 commit 28b38b9

File tree

9 files changed

+96
-33
lines changed

9 files changed

+96
-33
lines changed

src/algorithm/hgraph.cpp

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,17 @@ HGraph::KnnSearch(const DatasetPtr& query,
188188
int64_t k,
189189
const std::string& parameters,
190190
const FilterPtr& filter) const {
191+
return KnnSearch(query, k, parameters, filter, nullptr);
192+
}
193+
194+
DatasetPtr
195+
HGraph::KnnSearch(const DatasetPtr& query,
196+
int64_t k,
197+
const std::string& parameters,
198+
const FilterPtr& filter,
199+
Allocator *allocator) const {
191200
int64_t query_dim = query->GetDim();
201+
Allocator *search_allocator = allocator == nullptr ? allocator_ : allocator;
192202
CHECK_ARGUMENT(query_dim == dim_,
193203
fmt::format("query.dim({}) must be equal to index.dim({})", query_dim, dim_));
194204
// check k
@@ -203,6 +213,7 @@ HGraph::KnnSearch(const DatasetPtr& query,
203213
search_param.topk = 1;
204214
search_param.ef = 1;
205215
search_param.is_inner_id_allowed = nullptr;
216+
search_param.search_alloc = search_allocator;
206217
for (auto i = static_cast<int64_t>(this->route_graphs_.size() - 1); i >= 0; --i) {
207218
auto result = this->search_one_graph(query->GetFloat32Vectors(),
208219
this->route_graphs_[i],
@@ -240,10 +251,10 @@ HGraph::KnnSearch(const DatasetPtr& query,
240251
return DatasetImpl::MakeEmptyDataset();
241252
}
242253
auto count = static_cast<const int64_t>(search_result.size());
243-
auto [dataset_results, dists, ids] = CreateFastDataset(count, allocator_);
254+
auto [dataset_results, dists, ids] = CreateFastDataset(count, search_allocator);
244255
char* extra_infos = nullptr;
245256
if (extra_info_size_ > 0) {
246-
extra_infos = (char*)allocator_->Allocate(extra_info_size_ * search_result.size());
257+
extra_infos = (char*)search_allocator->Allocate(extra_info_size_ * search_result.size());
247258
dataset_results->ExtraInfos(extra_infos);
248259
}
249260
for (int64_t j = count - 1; j >= 0; --j) {
@@ -315,6 +326,7 @@ HGraph::KnnSearch(const DatasetPtr& query,
315326
search_param.topk = 1;
316327
search_param.ef = 1;
317328
search_param.is_inner_id_allowed = nullptr;
329+
search_param.search_alloc = allocator_;
318330
if (iter_filter_ctx->IsFirstUsed()) {
319331
for (auto i = static_cast<int64_t>(this->route_graphs_.size() - 1); i >= 0; --i) {
320332
auto result = this->search_one_graph(query->GetFloat32Vectors(),

src/algorithm/hgraph.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,13 @@ class HGraph : public InnerIndexInterface {
7878
const std::string& parameters,
7979
const FilterPtr& filter) const override;
8080

81+
[[nodiscard]] DatasetPtr
82+
KnnSearch(const DatasetPtr& query,
83+
int64_t k,
84+
const std::string& parameters,
85+
const FilterPtr& filter,
86+
Allocator *allocator) const override;
87+
8188
[[nodiscard]] DatasetPtr
8289
KnnSearch(const DatasetPtr& query,
8390
int64_t k,

src/algorithm/hnswlib/hnswalg.cpp

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -447,8 +447,12 @@ HierarchicalNSW::searchBaseLayerST(InnerIdType ep_id,
447447
vl_type visited_array_tag = vl->curV;
448448
vsag::Allocator *search_allocator = allocator == nullptr ? allocator_ : allocator;
449449

450-
MaxHeap top_candidates(search_allocator);
451-
MaxHeap candidate_set(search_allocator);
450+
vsag::Vector<std::pair<float, InnerIdType>> top_candidates_buffer(search_allocator);
451+
top_candidates_buffer.reserve(ef * 2);
452+
MaxHeap top_candidates(CompareByFirst(), top_candidates_buffer);
453+
vsag::Vector<std::pair<float, InnerIdType>> candidate_set_buffer(search_allocator);
454+
candidate_set_buffer.reserve(ef * 2);
455+
MaxHeap candidate_set(CompareByFirst(), candidate_set_buffer);
452456

453457
float valid_ratio = is_id_allowed ? is_id_allowed->ValidRatio() : 1.0F;
454458
float skip_threshold = valid_ratio == 1.0F ? 0 : (1 - ((1 - valid_ratio) * skip_ratio));
@@ -1490,7 +1494,10 @@ HierarchicalNSW::searchKnn(const void* query_data,
14901494
vsag::Allocator *search_alloctor = allocator == nullptr ? allocator_ : allocator;
14911495
std::shared_ptr<float[]> normalize_query;
14921496
normalizeVector(query_data, normalize_query);
1493-
MaxHeap top_candidates(search_alloctor);
1497+
vsag::Vector<std::pair<float, InnerIdType>> top_candidates_buffer(search_alloctor);
1498+
top_candidates_buffer.reserve(std::max(ef, k) * 2);
1499+
MaxHeap top_candidates(CompareByFirst(), top_candidates_buffer);
1500+
vsag::Vector<std::pair<float, InnerIdType>> candidate_set_buffer(search_alloctor);
14941501
if (iter_ctx != nullptr && !iter_ctx->IsFirstUsed()) {
14951502
if (iter_ctx->Empty())
14961503
return result;

src/algorithm/inner_index_interface.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,15 @@ class InnerIndexInterface {
8282
const std::string& parameters,
8383
const std::function<bool(int64_t)>& filter) const;
8484

85+
[[nodiscard]] virtual DatasetPtr
86+
KnnSearch(const DatasetPtr& query,
87+
int64_t k,
88+
const std::string& parameters,
89+
const FilterPtr& filter,
90+
Allocator *allocator) const {
91+
throw std::runtime_error("Index doesn't support new filter");
92+
};
93+
8594
[[nodiscard]] virtual DatasetPtr
8695
KnnSearch(const DatasetPtr& query,
8796
int64_t k,

src/data_cell/flatten_datacell.h

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,10 @@ class FlattenDataCell : public FlattenInterface {
4242
Query(float* result_dists,
4343
const ComputerInterfacePtr& computer,
4444
const InnerIdType* idx,
45-
InnerIdType id_count) override {
45+
InnerIdType id_count,
46+
Allocator *allocator = nullptr) override {
4647
auto comp = std::static_pointer_cast<Computer<QuantTmpl>>(computer);
47-
this->query(result_dists, comp, idx, id_count);
48+
this->query(result_dists, comp, idx, id_count, allocator);
4849
}
4950

5051
ComputerInterfacePtr
@@ -142,13 +143,15 @@ class FlattenDataCell : public FlattenInterface {
142143
query(float* result_dists,
143144
const float* query_vector,
144145
const InnerIdType* idx,
145-
InnerIdType id_count);
146+
InnerIdType id_count,
147+
Allocator* allocator);
146148

147149
inline void
148150
query(float* result_dists,
149151
const std::shared_ptr<Computer<QuantTmpl>>& computer,
150152
const InnerIdType* idx,
151-
InnerIdType id_count);
153+
InnerIdType id_count,
154+
Allocator* allocator);
152155

153156
ComputerInterfacePtr
154157
factory_computer(const float* query) {
@@ -305,18 +308,21 @@ void
305308
FlattenDataCell<QuantTmpl, IOTmpl>::query(float* result_dists,
306309
const float* query_vector,
307310
const InnerIdType* idx,
308-
InnerIdType id_count) {
311+
InnerIdType id_count,
312+
Allocator* allocator) {
309313
auto computer = quantizer_->FactoryComputer();
310314
computer->SetQuery(query_vector);
311-
this->Query(result_dists, computer, idx, id_count);
315+
this->Query(result_dists, computer, idx, id_count, allocator);
312316
}
313317

314318
template <typename QuantTmpl, typename IOTmpl>
315319
void
316320
FlattenDataCell<QuantTmpl, IOTmpl>::query(float* result_dists,
317321
const std::shared_ptr<Computer<QuantTmpl>>& computer,
318322
const InnerIdType* idx,
319-
InnerIdType id_count) {
323+
InnerIdType id_count,
324+
Allocator* allocator) {
325+
Allocator *search_alloc = allocator == nullptr ? allocator_ : allocator;
320326
for (uint32_t i = 0; i < this->prefetch_jump_code_size_ and i < id_count; i++) {
321327
if (force_in_memory_) {
322328
this->force_in_memory_io_->Prefetch(
@@ -328,9 +334,9 @@ FlattenDataCell<QuantTmpl, IOTmpl>::query(float* result_dists,
328334
}
329335
}
330336
if (not force_in_memory_ and not this->io_->InMemory() and id_count > 1) {
331-
ByteBuffer codes(id_count * this->code_size_, allocator_);
332-
Vector<uint64_t> sizes(id_count, this->code_size_, allocator_);
333-
Vector<uint64_t> offsets(id_count, this->code_size_, allocator_);
337+
ByteBuffer codes(id_count * this->code_size_, search_alloc);
338+
Vector<uint64_t> sizes(id_count, this->code_size_, search_alloc);
339+
Vector<uint64_t> offsets(id_count, this->code_size_, search_alloc);
334340
for (int64_t i = 0; i < id_count; ++i) {
335341
offsets[i] = idx[i] * code_size_;
336342
}

src/data_cell/flatten_interface.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@ class FlattenInterface {
4141
Query(float* result_dists,
4242
const ComputerInterfacePtr& computer,
4343
const InnerIdType* idx,
44-
InnerIdType id_count) = 0;
44+
InnerIdType id_count,
45+
Allocator *allocator = nullptr) = 0;
4546

4647
virtual ComputerInterfacePtr
4748
FactoryComputer(const float* query) = 0;

src/impl/basic_searcher.cpp

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,12 @@ BasicSearcher::search_impl(const GraphInterfacePtr& graph,
100100
const float* query,
101101
const InnerSearchParam& inner_search_param,
102102
IteratorFilterContext* iter_ctx) const {
103-
MaxHeap top_candidates(allocator_);
104-
MaxHeap candidate_set(allocator_);
103+
vsag::Vector<std::pair<float, InnerIdType>> top_candidates_buffer(inner_search_param.search_alloc);
104+
top_candidates_buffer.reserve(inner_search_param.ef * 2);
105+
MaxHeap top_candidates(CompareByFirst(), top_candidates_buffer);
106+
vsag::Vector<std::pair<float, InnerIdType>> candidate_set_buffer(inner_search_param.search_alloc);
107+
candidate_set_buffer.reserve(inner_search_param.ef * 2);
108+
MaxHeap candidate_set(CompareByFirst(), candidate_set_buffer);
105109

106110
if (not graph or not flatten) {
107111
return top_candidates;
@@ -120,10 +124,10 @@ BasicSearcher::search_impl(const GraphInterfacePtr& graph,
120124
uint32_t hops = 0;
121125
uint32_t dist_cmp = 0;
122126
uint32_t count_no_visited = 0;
123-
Vector<InnerIdType> to_be_visited_rid(graph->MaximumDegree(), allocator_);
124-
Vector<InnerIdType> to_be_visited_id(graph->MaximumDegree(), allocator_);
125-
Vector<InnerIdType> neighbors(graph->MaximumDegree(), allocator_);
126-
Vector<float> line_dists(graph->MaximumDegree(), allocator_);
127+
Vector<InnerIdType> to_be_visited_rid(graph->MaximumDegree(), inner_search_param.search_alloc);
128+
Vector<InnerIdType> to_be_visited_id(graph->MaximumDegree(), inner_search_param.search_alloc);
129+
Vector<InnerIdType> neighbors(graph->MaximumDegree(), inner_search_param.search_alloc);
130+
Vector<float> line_dists(graph->MaximumDegree(), inner_search_param.search_alloc);
127131

128132
if (!iter_ctx->IsFirstUsed()) {
129133
if (iter_ctx->Empty()) {
@@ -135,7 +139,7 @@ BasicSearcher::search_impl(const GraphInterfacePtr& graph,
135139
if (!vl->Get(cur_inner_id) && iter_ctx->CheckPoint(cur_inner_id)) {
136140
vl->Set(cur_inner_id);
137141
lower_bound = std::max(lower_bound, cur_dist);
138-
flatten->Query(&cur_dist, computer, &cur_inner_id, 1);
142+
flatten->Query(&cur_dist, computer, &cur_inner_id, 1, inner_search_param.search_alloc);
139143
top_candidates.emplace(cur_dist, cur_inner_id);
140144
candidate_set.emplace(cur_dist, cur_inner_id);
141145
if constexpr (mode == InnerSearchMode::RANGE_SEARCH) {
@@ -147,7 +151,7 @@ BasicSearcher::search_impl(const GraphInterfacePtr& graph,
147151
iter_ctx->PopDiscard();
148152
}
149153
} else {
150-
flatten->Query(&dist, computer, &ep, 1);
154+
flatten->Query(&dist, computer, &ep, 1, inner_search_param.search_alloc);
151155
if (not is_id_allowed || is_id_allowed->CheckValid(ep)) {
152156
top_candidates.emplace(dist, ep);
153157
lower_bound = top_candidates.top().first;
@@ -182,7 +186,7 @@ BasicSearcher::search_impl(const GraphInterfacePtr& graph,
182186

183187
dist_cmp += count_no_visited;
184188

185-
flatten->Query(line_dists.data(), computer, to_be_visited_id.data(), count_no_visited);
189+
flatten->Query(line_dists.data(), computer, to_be_visited_id.data(), count_no_visited, inner_search_param.search_alloc);
186190

187191
for (uint32_t i = 0; i < count_no_visited; i++) {
188192
dist = line_dists[i];
@@ -234,8 +238,12 @@ BasicSearcher::search_impl(const GraphInterfacePtr& graph,
234238
const VisitedListPtr& vl,
235239
const float* query,
236240
const InnerSearchParam& inner_search_param) const {
237-
MaxHeap top_candidates(allocator_);
238-
MaxHeap candidate_set(allocator_);
241+
vsag::Vector<std::pair<float, InnerIdType>> top_candidates_buffer(inner_search_param.search_alloc);
242+
top_candidates_buffer.reserve(inner_search_param.ef * 2);
243+
MaxHeap top_candidates(CompareByFirst(), top_candidates_buffer);
244+
vsag::Vector<std::pair<float, InnerIdType>> candidate_set_buffer(inner_search_param.search_alloc);
245+
candidate_set_buffer.reserve(inner_search_param.ef * 2);
246+
MaxHeap candidate_set(CompareByFirst(), candidate_set_buffer);
239247

240248
if (not graph or not flatten) {
241249
return top_candidates;
@@ -253,12 +261,12 @@ BasicSearcher::search_impl(const GraphInterfacePtr& graph,
253261
uint32_t hops = 0;
254262
uint32_t dist_cmp = 0;
255263
uint32_t count_no_visited = 0;
256-
Vector<InnerIdType> to_be_visited_rid(graph->MaximumDegree(), allocator_);
257-
Vector<InnerIdType> to_be_visited_id(graph->MaximumDegree(), allocator_);
258-
Vector<InnerIdType> neighbors(graph->MaximumDegree(), allocator_);
259-
Vector<float> line_dists(graph->MaximumDegree(), allocator_);
264+
Vector<InnerIdType> to_be_visited_rid(graph->MaximumDegree(), inner_search_param.search_alloc);
265+
Vector<InnerIdType> to_be_visited_id(graph->MaximumDegree(), inner_search_param.search_alloc);
266+
Vector<InnerIdType> neighbors(graph->MaximumDegree(), inner_search_param.search_alloc);
267+
Vector<float> line_dists(graph->MaximumDegree(), inner_search_param.search_alloc);
260268

261-
flatten->Query(&dist, computer, &ep, 1);
269+
flatten->Query(&dist, computer, &ep, 1, inner_search_param.search_alloc);
262270
if (not is_id_allowed || is_id_allowed->CheckValid(ep)) {
263271
top_candidates.emplace(dist, ep);
264272
lower_bound = top_candidates.top().first;
@@ -297,7 +305,7 @@ BasicSearcher::search_impl(const GraphInterfacePtr& graph,
297305

298306
dist_cmp += count_no_visited;
299307

300-
flatten->Query(line_dists.data(), computer, to_be_visited_id.data(), count_no_visited);
308+
flatten->Query(line_dists.data(), computer, to_be_visited_id.data(), count_no_visited, inner_search_param.search_alloc);
301309

302310
for (uint32_t i = 0; i < count_no_visited; i++) {
303311
dist = line_dists[i];

src/impl/basic_searcher.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ class InnerSearchParam {
3838
float skip_ratio{0.8F};
3939
InnerSearchMode search_mode{KNN_SEARCH};
4040
int range_search_limit_size{-1};
41+
Allocator *search_alloc;
4142
};
4243

4344
constexpr float THRESHOLD_ERROR = 2e-6;

src/index/index_impl.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,18 @@ class IndexImpl : public Index {
107107
SAFE_CALL(return this->inner_index_->KnnSearch(query, k, parameters, filter));
108108
}
109109

110+
tl::expected<DatasetPtr, Error>
111+
KnnSearch(const DatasetPtr& query,
112+
int64_t k,
113+
const std::string& parameters,
114+
const FilterPtr& filter,
115+
Allocator *allocator) const override {
116+
if (GetNumElements() == 0) {
117+
return DatasetImpl::MakeEmptyDataset();
118+
}
119+
SAFE_CALL(return this->inner_index_->KnnSearch(query, k, parameters, filter, allocator));
120+
}
121+
110122
tl::expected<DatasetPtr, Error>
111123
KnnSearch(const DatasetPtr& query,
112124
int64_t k,

0 commit comments

Comments
 (0)