@@ -150,7 +150,7 @@ namespace hnswlib {
150150 }
151151
152152 std::priority_queue<std::pair<dist_t , tableint>, std::vector<std::pair<dist_t , tableint>>, CompareByFirst>
153- searchBaseLayer (tableint ep_id, void *data_point, int layer) {
153+ searchBaseLayer (tableint ep_id, const void *data_point, int layer) {
154154 VisitedList *vl = visited_list_pool_->getFreeVisitedList ();
155155 vl_type *visited_array = vl->mass ;
156156 vl_type visited_array_tag = vl->curV ;
@@ -371,7 +371,7 @@ namespace hnswlib {
371371 return (linklistsizeint *) (linkLists_[internal_id] + (level - 1 ) * size_links_per_element_);
372372 };
373373
374- void mutuallyConnectNewElement (void *data_point, tableint cur_c,
374+ void mutuallyConnectNewElement (const void *data_point, tableint cur_c,
375375 std::priority_queue<std::pair<dist_t , tableint>, std::vector<std::pair<dist_t , tableint>>, CompareByFirst> top_candidates,
376376 int level) {
377377
@@ -484,6 +484,8 @@ namespace hnswlib {
484484
485485
486486 std::priority_queue<std::pair<dist_t , tableint>> searchKnnInternal (void *query_data, int k) {
487+ std::priority_queue<std::pair<dist_t , tableint >> top_candidates;
488+ if (cur_element_count == 0 ) return top_candidates;
487489 tableint currObj = enterpoint_node_;
488490 dist_t curdist = fstdistfunc_ (query_data, getDataByInternalId (enterpoint_node_), dist_func_param_);
489491
@@ -510,8 +512,6 @@ namespace hnswlib {
510512 }
511513 }
512514
513-
514- std::priority_queue<std::pair<dist_t , tableint >> top_candidates;
515515 if (has_deletions_) {
516516 std::priority_queue<std::pair<dist_t , tableint >> top_candidates1=searchBaseLayerST<true >(currObj, query_data,
517517 ef_);
@@ -779,11 +779,11 @@ namespace hnswlib {
779779 *((unsigned short int *)(ptr))=*((unsigned short int *)&size);
780780 }
781781
782- void addPoint (void *data_point, labeltype label) {
782+ void addPoint (const void *data_point, labeltype label) {
783783 addPoint (data_point, label,-1 );
784784 }
785785
786- tableint addPoint (void *data_point, labeltype label, int level) {
786+ tableint addPoint (const void *data_point, labeltype label, int level) {
787787 tableint cur_c = 0 ;
788788 {
789789 std::unique_lock <std::mutex> lock (cur_element_count_guard_);
@@ -895,7 +895,11 @@ namespace hnswlib {
895895 return cur_c;
896896 };
897897
898- std::priority_queue<std::pair<dist_t , labeltype >> searchKnn (const void *query_data, size_t k) const {
898+ std::priority_queue<std::pair<dist_t , labeltype >>
899+ searchKnn (const void *query_data, size_t k) const {
900+ std::priority_queue<std::pair<dist_t , labeltype >> result;
901+ if (cur_element_count == 0 ) return result;
902+
899903 tableint currObj = enterpoint_node_;
900904 dist_t curdist = fstdistfunc_ (query_data, getDataByInternalId (enterpoint_node_), dist_func_param_);
901905
@@ -934,18 +938,34 @@ namespace hnswlib {
934938 currObj, query_data, std::max (ef_, k));
935939 top_candidates.swap (top_candidates1);
936940 }
937- std::priority_queue<std::pair<dist_t , labeltype >> results;
938941 while (top_candidates.size () > k) {
939942 top_candidates.pop ();
940943 }
941944 while (top_candidates.size () > 0 ) {
942945 std::pair<dist_t , tableint> rez = top_candidates.top ();
943- results .push (std::pair<dist_t , labeltype>(rez.first , getExternalLabel (rez.second )));
946+ result .push (std::pair<dist_t , labeltype>(rez.first , getExternalLabel (rez.second )));
944947 top_candidates.pop ();
945948 }
946- return results ;
949+ return result ;
947950 };
948951
952+ template <typename Comp>
953+ std::vector<std::pair<dist_t , labeltype>>
954+ searchKnn (const void * query_data, size_t k, Comp comp) {
955+ std::vector<std::pair<dist_t , labeltype>> result;
956+ if (cur_element_count == 0 ) return result;
957+
958+ auto ret = searchKnn (query_data, k);
959+
960+ while (!ret.empty ()) {
961+ result.push_back (ret.top ());
962+ ret.pop ();
963+ }
964+
965+ std::sort (result.begin (), result.end (), comp);
966+
967+ return result;
968+ }
949969
950970 };
951971
0 commit comments