Skip to content

Commit 211b9ae

Browse files
author
zourunxin.zrx
committed
hgraph search allocator2
1 parent 28b38b9 commit 211b9ae

File tree

1 file changed

+19
-17
lines changed

1 file changed

+19
-17
lines changed

src/impl/basic_searcher.cpp

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -100,10 +100,11 @@ BasicSearcher::search_impl(const GraphInterfacePtr& graph,
100100
const float* query,
101101
const InnerSearchParam& inner_search_param,
102102
IteratorFilterContext* iter_ctx) const {
103-
vsag::Vector<std::pair<float, InnerIdType>> top_candidates_buffer(inner_search_param.search_alloc);
103+
Allocator *alloc = inner_search_param.search_alloc == nullptr ? allocator_ : inner_search_param.search_alloc;
104+
vsag::Vector<std::pair<float, InnerIdType>> top_candidates_buffer(alloc);
104105
top_candidates_buffer.reserve(inner_search_param.ef * 2);
105106
MaxHeap top_candidates(CompareByFirst(), top_candidates_buffer);
106-
vsag::Vector<std::pair<float, InnerIdType>> candidate_set_buffer(inner_search_param.search_alloc);
107+
vsag::Vector<std::pair<float, InnerIdType>> candidate_set_buffer(alloc);
107108
candidate_set_buffer.reserve(inner_search_param.ef * 2);
108109
MaxHeap candidate_set(CompareByFirst(), candidate_set_buffer);
109110

@@ -124,10 +125,10 @@ BasicSearcher::search_impl(const GraphInterfacePtr& graph,
124125
uint32_t hops = 0;
125126
uint32_t dist_cmp = 0;
126127
uint32_t count_no_visited = 0;
127-
Vector<InnerIdType> to_be_visited_rid(graph->MaximumDegree(), inner_search_param.search_alloc);
128-
Vector<InnerIdType> to_be_visited_id(graph->MaximumDegree(), inner_search_param.search_alloc);
129-
Vector<InnerIdType> neighbors(graph->MaximumDegree(), inner_search_param.search_alloc);
130-
Vector<float> line_dists(graph->MaximumDegree(), inner_search_param.search_alloc);
128+
Vector<InnerIdType> to_be_visited_rid(graph->MaximumDegree(), alloc);
129+
Vector<InnerIdType> to_be_visited_id(graph->MaximumDegree(), alloc);
130+
Vector<InnerIdType> neighbors(graph->MaximumDegree(), alloc);
131+
Vector<float> line_dists(graph->MaximumDegree(), alloc);
131132

132133
if (!iter_ctx->IsFirstUsed()) {
133134
if (iter_ctx->Empty()) {
@@ -139,7 +140,7 @@ BasicSearcher::search_impl(const GraphInterfacePtr& graph,
139140
if (!vl->Get(cur_inner_id) && iter_ctx->CheckPoint(cur_inner_id)) {
140141
vl->Set(cur_inner_id);
141142
lower_bound = std::max(lower_bound, cur_dist);
142-
flatten->Query(&cur_dist, computer, &cur_inner_id, 1, inner_search_param.search_alloc);
143+
flatten->Query(&cur_dist, computer, &cur_inner_id, 1, alloc);
143144
top_candidates.emplace(cur_dist, cur_inner_id);
144145
candidate_set.emplace(cur_dist, cur_inner_id);
145146
if constexpr (mode == InnerSearchMode::RANGE_SEARCH) {
@@ -151,7 +152,7 @@ BasicSearcher::search_impl(const GraphInterfacePtr& graph,
151152
iter_ctx->PopDiscard();
152153
}
153154
} else {
154-
flatten->Query(&dist, computer, &ep, 1, inner_search_param.search_alloc);
155+
flatten->Query(&dist, computer, &ep, 1, alloc);
155156
if (not is_id_allowed || is_id_allowed->CheckValid(ep)) {
156157
top_candidates.emplace(dist, ep);
157158
lower_bound = top_candidates.top().first;
@@ -186,7 +187,7 @@ BasicSearcher::search_impl(const GraphInterfacePtr& graph,
186187

187188
dist_cmp += count_no_visited;
188189

189-
flatten->Query(line_dists.data(), computer, to_be_visited_id.data(), count_no_visited, inner_search_param.search_alloc);
190+
flatten->Query(line_dists.data(), computer, to_be_visited_id.data(), count_no_visited, alloc);
190191

191192
for (uint32_t i = 0; i < count_no_visited; i++) {
192193
dist = line_dists[i];
@@ -238,10 +239,11 @@ BasicSearcher::search_impl(const GraphInterfacePtr& graph,
238239
const VisitedListPtr& vl,
239240
const float* query,
240241
const InnerSearchParam& inner_search_param) const {
241-
vsag::Vector<std::pair<float, InnerIdType>> top_candidates_buffer(inner_search_param.search_alloc);
242+
Allocator *alloc = inner_search_param.search_alloc == nullptr ? allocator_ : inner_search_param.search_alloc;
243+
vsag::Vector<std::pair<float, InnerIdType>> top_candidates_buffer(alloc);
242244
top_candidates_buffer.reserve(inner_search_param.ef * 2);
243245
MaxHeap top_candidates(CompareByFirst(), top_candidates_buffer);
244-
vsag::Vector<std::pair<float, InnerIdType>> candidate_set_buffer(inner_search_param.search_alloc);
246+
vsag::Vector<std::pair<float, InnerIdType>> candidate_set_buffer(alloc);
245247
candidate_set_buffer.reserve(inner_search_param.ef * 2);
246248
MaxHeap candidate_set(CompareByFirst(), candidate_set_buffer);
247249

@@ -261,12 +263,12 @@ BasicSearcher::search_impl(const GraphInterfacePtr& graph,
261263
uint32_t hops = 0;
262264
uint32_t dist_cmp = 0;
263265
uint32_t count_no_visited = 0;
264-
Vector<InnerIdType> to_be_visited_rid(graph->MaximumDegree(), inner_search_param.search_alloc);
265-
Vector<InnerIdType> to_be_visited_id(graph->MaximumDegree(), inner_search_param.search_alloc);
266-
Vector<InnerIdType> neighbors(graph->MaximumDegree(), inner_search_param.search_alloc);
267-
Vector<float> line_dists(graph->MaximumDegree(), inner_search_param.search_alloc);
266+
Vector<InnerIdType> to_be_visited_rid(graph->MaximumDegree(), alloc);
267+
Vector<InnerIdType> to_be_visited_id(graph->MaximumDegree(), alloc);
268+
Vector<InnerIdType> neighbors(graph->MaximumDegree(), alloc);
269+
Vector<float> line_dists(graph->MaximumDegree(), alloc);
268270

269-
flatten->Query(&dist, computer, &ep, 1, inner_search_param.search_alloc);
271+
flatten->Query(&dist, computer, &ep, 1, alloc);
270272
if (not is_id_allowed || is_id_allowed->CheckValid(ep)) {
271273
top_candidates.emplace(dist, ep);
272274
lower_bound = top_candidates.top().first;
@@ -305,7 +307,7 @@ BasicSearcher::search_impl(const GraphInterfacePtr& graph,
305307

306308
dist_cmp += count_no_visited;
307309

308-
flatten->Query(line_dists.data(), computer, to_be_visited_id.data(), count_no_visited, inner_search_param.search_alloc);
310+
flatten->Query(line_dists.data(), computer, to_be_visited_id.data(), count_no_visited, alloc);
309311

310312
for (uint32_t i = 0; i < count_no_visited; i++) {
311313
dist = line_dists[i];

0 commit comments

Comments
 (0)