@@ -439,14 +439,16 @@ HierarchicalNSW::searchBaseLayerST(InnerIdType ep_id,
439439 size_t ef,
440440 const vsag::FilterPtr is_id_allowed,
441441 const float skip_ratio,
442+ vsag::Allocator *allocator,
442443 vsag::IteratorFilterContext* iter_ctx) const {
443444 vsag::LinearCongruentialGenerator generator;
444445 VisitedListPtr vl = visited_list_pool_->getFreeVisitedList ();
445446 vl_type* visited_array = vl->mass ;
446447 vl_type visited_array_tag = vl->curV ;
448+ vsag::Allocator *search_allocator = allocator == nullptr ? allocator_ : allocator;
447449
448- MaxHeap top_candidates (allocator_ );
449- MaxHeap candidate_set (allocator_ );
450+ MaxHeap top_candidates (search_allocator );
451+ MaxHeap candidate_set (search_allocator );
450452
451453 float valid_ratio = is_id_allowed ? is_id_allowed->ValidRatio () : 1 .0F ;
452454 float skip_threshold = valid_ratio == 1 .0F ? 0 : (1 - ((1 - valid_ratio) * skip_ratio));
@@ -1477,16 +1479,18 @@ HierarchicalNSW::searchKnn(const void* query_data,
14771479 uint64_t ef,
14781480 const vsag::FilterPtr is_id_allowed,
14791481 const float skip_ratio,
1482+ vsag::Allocator *allocator,
14801483 vsag::IteratorFilterContext* iter_ctx,
14811484 bool is_last_filter) const {
14821485 std::shared_lock resize_lock (resize_mutex_);
14831486 std::priority_queue<std::pair<float , LabelType>> result;
14841487 if (cur_element_count_ == 0 )
14851488 return result;
14861489
1490+ vsag::Allocator *search_alloctor = allocator == nullptr ? allocator_ : allocator;
14871491 std::shared_ptr<float []> normalize_query;
14881492 normalizeVector (query_data, normalize_query);
1489- MaxHeap top_candidates (allocator_ );
1493+ MaxHeap top_candidates (search_alloctor );
14901494 if (iter_ctx != nullptr && !iter_ctx->IsFirstUsed ()) {
14911495 if (iter_ctx->Empty ())
14921496 return result;
@@ -1504,6 +1508,7 @@ HierarchicalNSW::searchKnn(const void* query_data,
15041508 std::max (ef, k),
15051509 is_id_allowed,
15061510 skip_ratio,
1511+ allocator,
15071512 iter_ctx);
15081513 } else {
15091514 int64_t currObj;
@@ -1545,10 +1550,10 @@ HierarchicalNSW::searchKnn(const void* query_data,
15451550
15461551 if (num_deleted_ == 0 ) {
15471552 top_candidates = searchBaseLayerST<false , true >(
1548- currObj, query_data, std::max (ef, k), is_id_allowed, skip_ratio, iter_ctx);
1553+ currObj, query_data, std::max (ef, k), is_id_allowed, skip_ratio, allocator, iter_ctx);
15491554 } else {
15501555 top_candidates = searchBaseLayerST<true , true >(
1551- currObj, query_data, std::max (ef, k), is_id_allowed, skip_ratio, iter_ctx);
1556+ currObj, query_data, std::max (ef, k), is_id_allowed, skip_ratio, allocator, iter_ctx);
15521557 }
15531558 }
15541559
@@ -1669,6 +1674,7 @@ HierarchicalNSW::searchBaseLayerST<false, false>(
16691674 size_t ef,
16701675 const vsag::FilterPtr is_id_allowed,
16711676 const float skip_ratio,
1677+ vsag::Allocator *allocator,
16721678 vsag::IteratorFilterContext* iter_ctx = nullptr ) const ;
16731679
16741680template MaxHeap
0 commit comments