@@ -61,6 +61,8 @@ namespace hnswlib {
6161 maxlevel_ = -1 ;
6262
6363 linkLists_ = (char **) malloc (sizeof (void *) * max_elements_);
64+ if (linkLists_ == nullptr )
65+ throw std::runtime_error (" Not enough memory: HierarchicalNSW failed to allocate linklists" );
6466 size_links_per_element_ = maxM_ * sizeof (tableint) + sizeof (linklistsizeint);
6567 mult_ = 1 / log (1.0 * M_);
6668 revSize_ = 1.0 / mult_;
@@ -150,7 +152,7 @@ namespace hnswlib {
150152 }
151153
152154 std::priority_queue<std::pair<dist_t , tableint>, std::vector<std::pair<dist_t , tableint>>, CompareByFirst>
153- searchBaseLayer (tableint ep_id, void *data_point, int layer) {
155+ searchBaseLayer (tableint ep_id, const void *data_point, int layer) {
154156 VisitedList *vl = visited_list_pool_->getFreeVisitedList ();
155157 vl_type *visited_array = vl->mass ;
156158 vl_type visited_array_tag = vl->curV ;
@@ -371,7 +373,7 @@ namespace hnswlib {
371373 return (linklistsizeint *) (linkLists_[internal_id] + (level - 1 ) * size_links_per_element_);
372374 };
373375
374- void mutuallyConnectNewElement (void *data_point, tableint cur_c,
376+ void mutuallyConnectNewElement (const void *data_point, tableint cur_c,
375377 std::priority_queue<std::pair<dist_t , tableint>, std::vector<std::pair<dist_t , tableint>>, CompareByFirst> top_candidates,
376378 int level) {
377379
@@ -484,6 +486,8 @@ namespace hnswlib {
484486
485487
486488 std::priority_queue<std::pair<dist_t , tableint>> searchKnnInternal (void *query_data, int k) {
489+ std::priority_queue<std::pair<dist_t , tableint >> top_candidates;
490+ if (cur_element_count == 0 ) return top_candidates;
487491 tableint currObj = enterpoint_node_;
488492 dist_t curdist = fstdistfunc_ (query_data, getDataByInternalId (enterpoint_node_), dist_func_param_);
489493
@@ -510,8 +514,6 @@ namespace hnswlib {
510514 }
511515 }
512516
513-
514- std::priority_queue<std::pair<dist_t , tableint >> top_candidates;
515517 if (has_deletions_) {
516518 std::priority_queue<std::pair<dist_t , tableint >> top_candidates1=searchBaseLayerST<true >(currObj, query_data,
517519 ef_);
@@ -546,12 +548,16 @@ namespace hnswlib {
546548
547549 // Reallocate base layer
548550 char * data_level0_memory_new = (char *) malloc (new_max_elements * size_data_per_element_);
551+ if (data_level0_memory_new == nullptr )
552+ throw std::runtime_error (" Not enough memory: resizeIndex failed to allocate base layer" );
549553 memcpy (data_level0_memory_new, data_level0_memory_,cur_element_count * size_data_per_element_);
550554 free (data_level0_memory_);
551555 data_level0_memory_=data_level0_memory_new;
552556
553557 // Reallocate all other layers
554558 char ** linkLists_new = (char **) malloc (sizeof (void *) * new_max_elements);
559+ if (linkLists_new == nullptr )
560+ throw std::runtime_error (" Not enough memory: resizeIndex failed to allocate other layers" );
555561 memcpy (linkLists_new, linkLists_,cur_element_count * sizeof (void *));
556562 free (linkLists_);
557563 linkLists_=linkLists_new;
@@ -659,6 +665,8 @@ namespace hnswlib {
659665
660666
661667 data_level0_memory_ = (char *) malloc (max_elements * size_data_per_element_);
668+ if (data_level0_memory_ == nullptr )
669+ throw std::runtime_error (" Not enough memory: loadIndex failed to allocate level0" );
662670 input.read (data_level0_memory_, cur_element_count * size_data_per_element_);
663671
664672
@@ -675,6 +683,8 @@ namespace hnswlib {
675683
676684
677685 linkLists_ = (char **) malloc (sizeof (void *) * max_elements);
686+ if (linkLists_ == nullptr )
687+ throw std::runtime_error (" Not enough memory: loadIndex failed to allocate linklists" );
678688 element_levels_ = std::vector<int >(max_elements);
679689 revSize_ = 1.0 / mult_;
680690 ef_ = 10 ;
@@ -689,6 +699,8 @@ namespace hnswlib {
689699 } else {
690700 element_levels_[i] = linkListSize / size_links_per_element_;
691701 linkLists_[i] = (char *) malloc (linkListSize);
702+ if (linkLists_[i] == nullptr )
703+ throw std::runtime_error (" Not enough memory: loadIndex failed to allocate linklist" );
692704 input.read (linkLists_[i], linkListSize);
693705 }
694706 }
@@ -779,11 +791,11 @@ namespace hnswlib {
779791 *((unsigned short int *)(ptr))=*((unsigned short int *)&size);
780792 }
781793
782- void addPoint (void *data_point, labeltype label) {
794+ void addPoint (const void *data_point, labeltype label) {
783795 addPoint (data_point, label,-1 );
784796 }
785797
786- tableint addPoint (void *data_point, labeltype label, int level) {
798+ tableint addPoint (const void *data_point, labeltype label, int level) {
787799 tableint cur_c = 0 ;
788800 {
789801 std::unique_lock <std::mutex> lock (cur_element_count_guard_);
@@ -797,6 +809,7 @@ namespace hnswlib {
797809 auto search = label_lookup_.find (label);
798810 if (search != label_lookup_.end ()) {
799811 std::unique_lock <std::mutex> lock_el (link_list_locks_[search->second ]);
812+ has_deletions_ = true ;
800813 markDeletedInternal (search->second );
801814 }
802815 label_lookup_[label] = cur_c;
@@ -827,6 +840,8 @@ namespace hnswlib {
827840
828841 if (curlevel) {
829842 linkLists_[cur_c] = (char *) malloc (size_links_per_element_ * curlevel + 1 );
843+ if (linkLists_[cur_c] == nullptr )
844+ throw std::runtime_error (" Not enough memory: addPoint failed to allocate linklist" );
830845 memset (linkLists_[cur_c], 0 , size_links_per_element_ * curlevel + 1 );
831846 }
832847
@@ -895,7 +910,11 @@ namespace hnswlib {
895910 return cur_c;
896911 };
897912
898- std::priority_queue<std::pair<dist_t , labeltype >> searchKnn (const void *query_data, size_t k) const {
913+ std::priority_queue<std::pair<dist_t , labeltype >>
914+ searchKnn (const void *query_data, size_t k) const {
915+ std::priority_queue<std::pair<dist_t , labeltype >> result;
916+ if (cur_element_count == 0 ) return result;
917+
899918 tableint currObj = enterpoint_node_;
900919 dist_t curdist = fstdistfunc_ (query_data, getDataByInternalId (enterpoint_node_), dist_func_param_);
901920
@@ -934,18 +953,34 @@ namespace hnswlib {
934953 currObj, query_data, std::max (ef_, k));
935954 top_candidates.swap (top_candidates1);
936955 }
937- std::priority_queue<std::pair<dist_t , labeltype >> results;
938956 while (top_candidates.size () > k) {
939957 top_candidates.pop ();
940958 }
941959 while (top_candidates.size () > 0 ) {
942960 std::pair<dist_t , tableint> rez = top_candidates.top ();
943- results .push (std::pair<dist_t , labeltype>(rez.first , getExternalLabel (rez.second )));
961+ result .push (std::pair<dist_t , labeltype>(rez.first , getExternalLabel (rez.second )));
944962 top_candidates.pop ();
945963 }
946- return results ;
964+ return result ;
947965 };
948966
967+ template <typename Comp>
968+ std::vector<std::pair<dist_t , labeltype>>
969+ searchKnn (const void * query_data, size_t k, Comp comp) {
970+ std::vector<std::pair<dist_t , labeltype>> result;
971+ if (cur_element_count == 0 ) return result;
972+
973+ auto ret = searchKnn (query_data, k);
974+
975+ while (!ret.empty ()) {
976+ result.push_back (ret.top ());
977+ ret.pop ();
978+ }
979+
980+ std::sort (result.begin (), result.end (), comp);
981+
982+ return result;
983+ }
949984
950985 };
951986
0 commit comments