1111#include < absl/strings/str_split.h>
1212
1313#include < boost/iterator/function_output_iterator.hpp>
14+ #include < shared_mutex>
15+
16+ #include " core/search/base.h"
17+ #include " core/search/vector_utils.h"
18+ #include " util/fibers/synchronization.h"
1419
1520#define UNI_ALGO_DISABLE_NFKC_NFKD
1621
@@ -128,6 +133,16 @@ std::optional<GeoIndex::point> GetGeoPoint(const DocumentAccessor& doc, string_v
128133 return GeoIndex::point{lon, lat};
129134}
130135
136+ template <typename Q, typename T = GlobalDocId> vector<pair<float , T>> QueueToVec (Q queue) {
137+ vector<pair<float , T>> out (queue.size ());
138+ size_t idx = out.size ();
139+ while (!queue.empty ()) {
140+ out[--idx] = queue.top ();
141+ queue.pop ();
142+ }
143+ return out;
144+ }
145+
131146}; // namespace
132147
133148class RangeTreeAdapter : public NumericIndex ::RangeTreeBase {
@@ -492,69 +507,73 @@ bool BaseVectorIndex<T>::Add(T id, const DocumentAccessor& doc, std::string_view
492507 return true ;
493508}
494509
495- template <typename T>
496- FlatVectorIndex<T>::FlatVectorIndex(const SchemaField::VectorParams& params,
497- PMR_NS::memory_resource* mr)
498- : BaseVectorIndex<T>{params.dim , params.sim }, entries_{mr} {
499- DCHECK (!params.use_hnsw );
500- entries_.reserve (params.capacity * params.dim );
510+ ShardNoOpVectorIndex::ShardNoOpVectorIndex (const SchemaField::VectorParams& params)
511+ : BaseVectorIndex<DocId>{params.dim , params.sim } {
501512}
502513
503- template <typename T>
504- void FlatVectorIndex<T>::AddVector(T id, const typename BaseVectorIndex<T>::VectorPtr& vector) {
505- DCHECK_LE (id * BaseVectorIndex<T>::dim_, entries_.size ());
506- if (id * BaseVectorIndex<T>::dim_ == entries_.size ())
507- entries_.resize ((id + 1 ) * BaseVectorIndex<T>::dim_);
514+ FlatVectorIndex::FlatVectorIndex (const SchemaField::VectorParams& params, ShardId shard_set_size,
515+ PMR_NS::memory_resource* mr)
516+ : BaseVectorIndex<GlobalDocId>{params.dim , params.sim },
517+ entries_{mr},
518+ shard_vector_locks_ (shard_set_size) {
519+ DCHECK (!params.use_hnsw );
520+ entries_.resize (shard_set_size);
521+ for (size_t i = 0 ; i < shard_set_size; i++) {
522+ entries_[i].reserve (params.capacity * params.dim );
523+ }
524+ }
508525
509- // TODO: Let get vector write to buf itself
526+ void FlatVectorIndex::AddVector (GlobalDocId id,
527+ const typename BaseVectorIndex<GlobalDocId>::VectorPtr& vector) {
528+ auto shard_id = search::GlobalDocIdShardId (id);
529+ auto shard_doc_id = search::GlobalDocIdLocalId (id);
530+ DCHECK_LE (shard_doc_id * BaseVectorIndex<GlobalDocId>::dim_, entries_[shard_id].size ());
531+ if (shard_doc_id * BaseVectorIndex<GlobalDocId>::dim_ == entries_[shard_id].size ()) {
532+ unique_lock<util::fb2::SharedMutex> lock{shard_vector_locks_[shard_id]};
533+ entries_[shard_id].resize ((shard_doc_id + 1 ) * BaseVectorIndex<GlobalDocId>::dim_);
534+ }
510535 if (vector) {
511- memcpy (&entries_[id * BaseVectorIndex<T >::dim_], vector.get (),
512- BaseVectorIndex<T >::dim_ * sizeof (float ));
536+ memcpy (&entries_[shard_id][shard_doc_id * BaseVectorIndex<GlobalDocId >::dim_], vector.get (),
537+ BaseVectorIndex<GlobalDocId >::dim_ * sizeof (float ));
513538 }
514539}
515540
516- template <typename T>
517- void FlatVectorIndex<T>::Remove(T id, const DocumentAccessor& doc, string_view field) {
541+ void FlatVectorIndex::Remove (GlobalDocId id, const DocumentAccessor& doc, string_view field) {
518542 // noop
519543}
520544
521- template <typename T> const float * FlatVectorIndex<T>::Get(T doc) const {
522- return &entries_[doc * dim_];
523- }
545+ std::vector<std::pair<float , GlobalDocId>> FlatVectorIndex::Knn (float * target, size_t k) const {
546+ std::priority_queue<std::pair<float , search::GlobalDocId>> queue;
524547
525- template <typename T> std::vector<T> FlatVectorIndex<T>::GetAllDocsWithNonNullValues() const {
526- std::vector<DocId> result;
548+ for (size_t shard_id = 0 ; shard_id < entries_.size (); shard_id++) {
549+ shared_lock<util::fb2::SharedMutex> lock{shard_vector_locks_[shard_id]};
550+ size_t num_vectors = entries_[shard_id].size () / BaseVectorIndex<GlobalDocId>::dim_;
551+ for (GlobalDocId id = 0 ; id < num_vectors; ++id) {
552+ const float * vec = &entries_[shard_id][id * dim_];
553+ float dist = VectorDistance (target, vec, dim_, sim_);
554+ queue.emplace (dist, CreateGlobalDocId (shard_id, id));
555+ }
556+ }
527557
528- size_t num_vectors = entries_. size () / BaseVectorIndex<T>::dim_ ;
529- result. reserve (num_vectors);
558+ return QueueToVec (queue) ;
559+ }
530560
531- for (DocId id = 0 ; id < num_vectors; ++id) {
532- // Check if the vector is not zero (all elements are 0)
533- // TODO: Valid vector can contain 0s, we should use a better approach
534- const float * vec = Get (id);
535- bool is_zero_vector = true ;
561+ std::vector<std::pair<float , GlobalDocId>> FlatVectorIndex::Knn (
562+ float * target, size_t k, const std::vector<FilterShardDocs>& allowed_docs) const {
563+ std::priority_queue<std::pair<float , search::GlobalDocId>> queue;
536564
537- // TODO: Consider don't use check for zero vector
538- for (size_t i = 0 ; i < BaseVectorIndex<T>::dim_; ++i) {
539- if (vec[i] != 0 .0f ) { // TODO: Consider using a threshold for float comparison
540- is_zero_vector = false ;
541- break ;
542- }
543- }
544-
545- if (!is_zero_vector) {
546- result.push_back (id);
565+ for (size_t shard_id = 0 ; shard_id < allowed_docs.size (); shard_id++) {
566+ shared_lock<util::fb2::SharedMutex> lock{shard_vector_locks_[shard_id]};
567+ for (auto & shard_doc_id : allowed_docs[shard_id]) {
568+ const float * vec = &entries_[shard_id][shard_doc_id * dim_];
569+ float dist = VectorDistance (target, vec, dim_, sim_);
570+ queue.emplace (dist, CreateGlobalDocId (shard_id, shard_doc_id));
547571 }
548572 }
549-
550- // Result is already sorted by id, no need to sort again
551- // Also it has no duplicates
552- return result;
573+ return QueueToVec (queue);
553574}
554575
555- template struct FlatVectorIndex <DocId>;
556-
557- struct HnswlibAdapter {
576+ template <typename T> struct HnswlibAdapter {
558577 // Default setting of hnswlib/hnswalg
559578 constexpr static size_t kDefaultEfRuntime = 10 ;
560579
@@ -564,34 +583,45 @@ struct HnswlibAdapter {
564583 100 /* seed*/ } {
565584 }
566585
567- void Add (const float * data, DocId id) {
568- if (world_.cur_element_count + 1 >= world_.max_elements_ )
569- world_.resizeIndex (world_.cur_element_count * 2 );
570- world_.addPoint (data, id);
586+ void Add (const float * data, T id) {
587+ while (true ) {
588+ try {
589+ absl::ReaderMutexLock lock (&resize_mutex_);
590+ world_.addPoint (data, id);
591+ return ;
592+ } catch (const std::exception& e) {
593+ std::string error_msg = e.what ();
594+ if (absl::StrContains (error_msg, " The number of elements exceeds the specified limit" )) {
595+ ResizeIfFull ();
596+ continue ;
597+ }
598+ throw e;
599+ }
600+ }
571601 }
572602
573- void Remove (DocId id) {
603+ void Remove (T id) {
574604 try {
575605 world_.markDelete (id);
576606 } catch (const std::exception& e) {
577607 }
578608 }
579609
580- vector<pair<float , DocId >> Knn (float * target, size_t k, std::optional<size_t > ef) {
610+ vector<pair<float , T >> Knn (float * target, size_t k, std::optional<size_t > ef) {
581611 world_.setEf (ef.value_or (kDefaultEfRuntime ));
582612 return QueueToVec (world_.searchKnn (target, k));
583613 }
584614
585- vector<pair<float , DocId >> Knn (float * target, size_t k, std::optional<size_t > ef,
586- const vector<DocId >& allowed) {
615+ vector<pair<float , T >> Knn (float * target, size_t k, std::optional<size_t > ef,
616+ const vector<T >& allowed) {
587617 struct BinsearchFilter : hnswlib::BaseFilterFunctor {
588618 virtual bool operator ()(hnswlib::labeltype id) {
589619 return binary_search (allowed->begin (), allowed->end (), id);
590620 }
591621
592- BinsearchFilter (const vector<DocId >* allowed) : allowed{allowed} {
622+ BinsearchFilter (const vector<T >* allowed) : allowed{allowed} {
593623 }
594- const vector<DocId >* allowed;
624+ const vector<T >* allowed;
595625 };
596626
597627 world_.setEf (ef.value_or (kDefaultEfRuntime ));
@@ -613,57 +643,64 @@ struct HnswlibAdapter {
613643 return visit ([](auto & space) -> hnswlib::SpaceInterface<float >* { return &space; }, space_);
614644 }
615645
616- template <typename Q> static vector<pair<float , DocId>> QueueToVec (Q queue) {
617- vector<pair<float , DocId>> out (queue.size ());
618- size_t idx = out.size ();
619- while (!queue.empty ()) {
620- out[--idx] = queue.top ();
621- queue.pop ();
646+ void ResizeIfFull () {
647+ {
648+ absl::ReaderMutexLock lock (&resize_mutex_);
649+ if (world_.getCurrentElementCount () < world_.getMaxElements () ||
650+ (world_.allow_replace_deleted_ && world_.getDeletedCount () > 0 )) {
651+ return ;
652+ }
653+ }
654+ try {
655+ absl::WriterMutexLock lock (&resize_mutex_);
656+ if (world_.getCurrentElementCount () == world_.getMaxElements () &&
657+ (!world_.allow_replace_deleted_ || world_.getDeletedCount () == 0 )) {
658+ auto max_elements = world_.getMaxElements ();
659+ world_.resizeIndex (max_elements * 2 );
660+ LOG (INFO) << " Resizing HNSW Index, current size: " << max_elements
661+ << " , expand by: " << max_elements * 2 ;
662+ }
663+ } catch (const std::exception& e) {
664+ throw e;
622665 }
623- return out;
624666 }
625667
626668 SpaceUnion space_;
627669 hnswlib::HierarchicalNSW<float > world_;
670+ absl::Mutex resize_mutex_;
628671};
629672
630- template <typename T>
631- HnswVectorIndex<T>::HnswVectorIndex(const SchemaField::VectorParams& params,
632- PMR_NS::memory_resource*)
633- : BaseVectorIndex<T>{params.dim , params.sim }, adapter_{make_unique<HnswlibAdapter>(params)} {
673+ HnswVectorIndex::HnswVectorIndex (const SchemaField::VectorParams& params, PMR_NS::memory_resource*)
674+ : BaseVectorIndex<GlobalDocId>{params.dim , params.sim },
675+ adapter_{make_unique<HnswlibAdapter<GlobalDocId>>(params)} {
634676 DCHECK (params.use_hnsw );
635677 // TODO: Patch hnsw to use MR
636678}
637- template < typename T> HnswVectorIndex<T> ::~HnswVectorIndex () {
679+ HnswVectorIndex::~HnswVectorIndex () {
638680}
639681
640- template < typename T>
641- void HnswVectorIndex<T>::AddVector(T id, const typename BaseVectorIndex<T >::VectorPtr& vector) {
682+ void HnswVectorIndex::AddVector (GlobalDocId id,
683+ const typename BaseVectorIndex<GlobalDocId >::VectorPtr& vector) {
642684 if (vector) {
643685 adapter_->Add (vector.get (), id);
644686 }
645687}
646688
647- template <typename T>
648- std::vector<std::pair<float , T>> HnswVectorIndex<T>::Knn(float * target, size_t k,
649- std::optional<size_t > ef) const {
689+ std::vector<std::pair<float , GlobalDocId>> HnswVectorIndex::Knn (float * target, size_t k,
690+ std::optional<size_t > ef) const {
650691 return adapter_->Knn (target, k, ef);
651692}
652693
653- template <typename T>
654- std::vector<std::pair<float , T>> HnswVectorIndex<T>::Knn(float * target, size_t k,
655- std::optional<size_t > ef,
656- const std::vector<T>& allowed) const {
694+ std::vector<std::pair<float , GlobalDocId>> HnswVectorIndex::Knn (
695+ float * target, size_t k, std::optional<size_t > ef,
696+ const std::vector<GlobalDocId>& allowed) const {
657697 return adapter_->Knn (target, k, ef, allowed);
658698}
659699
660- template <typename T>
661- void HnswVectorIndex<T>::Remove(T id, const DocumentAccessor& doc, string_view field) {
700+ void HnswVectorIndex::Remove (GlobalDocId id, const DocumentAccessor& doc, string_view field) {
662701 adapter_->Remove (id);
663702}
664703
665- template struct HnswVectorIndex <DocId>;
666-
667704GeoIndex::GeoIndex (PMR_NS::memory_resource* mr) : rtree_(make_unique<rtree>()) {
668705}
669706
0 commit comments