@@ -523,12 +523,12 @@ template <typename T> const float* FlatVectorIndex<T>::Get(T doc) const {
523523}
524524
525525template <typename T> std::vector<T> FlatVectorIndex<T>::GetAllDocsWithNonNullValues() const {
526- std::vector<DocId > result;
526+ std::vector<T > result;
527527
528528 size_t num_vectors = entries_.size () / BaseVectorIndex<T>::dim_;
529529 result.reserve (num_vectors);
530530
531- for (DocId id = 0 ; id < num_vectors; ++id) {
531+ for (T id = 0 ; id < num_vectors; ++id) {
532532 // Check if the vector is not zero (all elements are 0)
533533 // TODO: Valid vector can contain 0s, we should use a better approach
534534 const float * vec = Get (id);
@@ -553,8 +553,9 @@ template <typename T> std::vector<T> FlatVectorIndex<T>::GetAllDocsWithNonNullVa
553553}
554554
555555template struct FlatVectorIndex <DocId>;
556+ template struct FlatVectorIndex <GlobalDocId>;
556557
557- struct HnswlibAdapter {
558+ template < typename T> struct HnswlibAdapter {
558559 // Default setting of hnswlib/hnswalg
559560 constexpr static size_t kDefaultEfRuntime = 10 ;
560561
@@ -564,34 +565,45 @@ struct HnswlibAdapter {
564565 100 /* seed*/ } {
565566 }
566567
567- void Add (const float * data, DocId id) {
568- if (world_.cur_element_count + 1 >= world_.max_elements_ )
569- world_.resizeIndex (world_.cur_element_count * 2 );
570- world_.addPoint (data, id);
568+ void Add (const float * data, T id) {
569+ while (true ) {
570+ try {
571+ absl::ReaderMutexLock lock (&resize_mutex_);
572+ world_.addPoint (data, id);
573+ return ;
574+ } catch (const std::exception& e) {
575+ std::string error_msg = e.what ();
576+ if (absl::StrContains (error_msg, " The number of elements exceeds the specified limit" )) {
577+ ResizeIfFull ();
578+ continue ;
579+ }
580+ throw e;
581+ }
582+ }
571583 }
572584
573- void Remove (DocId id) {
585+ void Remove (T id) {
574586 try {
575587 world_.markDelete (id);
576588 } catch (const std::exception& e) {
577589 }
578590 }
579591
580- vector<pair<float , DocId >> Knn (float * target, size_t k, std::optional<size_t > ef) {
592+ vector<pair<float , T >> Knn (float * target, size_t k, std::optional<size_t > ef) {
581593 world_.setEf (ef.value_or (kDefaultEfRuntime ));
582594 return QueueToVec (world_.searchKnn (target, k));
583595 }
584596
585- vector<pair<float , DocId >> Knn (float * target, size_t k, std::optional<size_t > ef,
586- const vector<DocId >& allowed) {
597+ vector<pair<float , T >> Knn (float * target, size_t k, std::optional<size_t > ef,
598+ const vector<T >& allowed) {
587599 struct BinsearchFilter : hnswlib::BaseFilterFunctor {
588600 virtual bool operator ()(hnswlib::labeltype id) {
589601 return binary_search (allowed->begin (), allowed->end (), id);
590602 }
591603
592- BinsearchFilter (const vector<DocId >* allowed) : allowed{allowed} {
604+ BinsearchFilter (const vector<T >* allowed) : allowed{allowed} {
593605 }
594- const vector<DocId >* allowed;
606+ const vector<T >* allowed;
595607 };
596608
597609 world_.setEf (ef.value_or (kDefaultEfRuntime ));
@@ -613,8 +625,8 @@ struct HnswlibAdapter {
613625 return visit ([](auto & space) -> hnswlib::SpaceInterface<float >* { return &space; }, space_);
614626 }
615627
616- template <typename Q> static vector<pair<float , DocId >> QueueToVec (Q queue) {
617- vector<pair<float , DocId >> out (queue.size ());
628+ template <typename Q> static vector<pair<float , T >> QueueToVec (Q queue) {
629+ vector<pair<float , T >> out (queue.size ());
618630 size_t idx = out.size ();
619631 while (!queue.empty ()) {
620632 out[--idx] = queue.top ();
@@ -623,14 +635,37 @@ struct HnswlibAdapter {
623635 return out;
624636 }
625637
638+ void ResizeIfFull () {
639+ {
640+ absl::ReaderMutexLock lock (&resize_mutex_);
641+ if (world_.getCurrentElementCount () < world_.getMaxElements () ||
642+ (world_.allow_replace_deleted_ && world_.getDeletedCount () > 0 )) {
643+ return ;
644+ }
645+ }
646+ try {
647+ absl::WriterMutexLock lock (&resize_mutex_);
648+ if (world_.getCurrentElementCount () == world_.getMaxElements () &&
649+ (!world_.allow_replace_deleted_ || world_.getDeletedCount () == 0 )) {
650+ auto max_elements = world_.getMaxElements ();
651+ world_.resizeIndex (max_elements * 2 );
652+ LOG (INFO) << " Resizing HNSW Index, current size: " << max_elements
653+ << " , expand by: " << max_elements * 2 ;
654+ }
655+ } catch (const std::exception& e) {
656+ throw e;
657+ }
658+ }
659+
626660 SpaceUnion space_;
627661 hnswlib::HierarchicalNSW<float > world_;
662+ absl::Mutex resize_mutex_;
628663};
629664
630665template <typename T>
631666HnswVectorIndex<T>::HnswVectorIndex(const SchemaField::VectorParams& params,
632667 PMR_NS::memory_resource*)
633- : BaseVectorIndex<T>{params.dim , params.sim }, adapter_{make_unique<HnswlibAdapter>(params)} {
668+ : BaseVectorIndex<T>{params.dim , params.sim }, adapter_{make_unique<HnswlibAdapter<T> >(params)} {
634669 DCHECK (params.use_hnsw );
635670 // TODO: Patch hnsw to use MR
636671}
@@ -663,6 +698,7 @@ void HnswVectorIndex<T>::Remove(T id, const DocumentAccessor& doc, string_view f
663698}
664699
665700template struct HnswVectorIndex <DocId>;
701+ template struct HnswVectorIndex <GlobalDocId>;
666702
667703GeoIndex::GeoIndex (PMR_NS::memory_resource* mr) : rtree_(make_unique<rtree>()) {
668704}
0 commit comments