1111#include < absl/strings/str_split.h>
1212
1313#include < boost/iterator/function_output_iterator.hpp>
14+ #include < shared_mutex>
15+
16+ #include " core/search/base.h"
17+ #include " core/search/vector_utils.h"
18+ #include " util/fibers/synchronization.h"
1419
1520#define UNI_ALGO_DISABLE_NFKC_NFKD
1621
@@ -492,58 +497,116 @@ bool BaseVectorIndex<T>::Add(T id, const DocumentAccessor& doc, std::string_view
492497 return true ;
493498}
494499
495- template <typename T>
496- FlatVectorIndex<T>::FlatVectorIndex(const SchemaField::VectorParams& params,
497- PMR_NS::memory_resource* mr)
498- : BaseVectorIndex<T>{params.dim , params.sim }, entries_{mr} {
499- DCHECK (!params.use_hnsw );
500- entries_.reserve (params.capacity * params.dim );
500+ ShardNoOpVectorIndex::ShardNoOpVectorIndex (const SchemaField::VectorParams& params)
501+ : BaseVectorIndex<DocId>{params.dim , params.sim } {
501502}
502503
503- template <typename T>
504- void FlatVectorIndex<T>::AddVector(T id, const typename BaseVectorIndex<T>::VectorPtr& vector) {
505- DCHECK_LE (id * BaseVectorIndex<T>::dim_, entries_.size ());
506- if (id * BaseVectorIndex<T>::dim_ == entries_.size ())
507- entries_.resize ((id + 1 ) * BaseVectorIndex<T>::dim_);
504+ FlatVectorIndex::FlatVectorIndex (const SchemaField::VectorParams& params, ShardId shard_set_size,
505+ PMR_NS::memory_resource* mr)
506+ : BaseVectorIndex<GlobalDocId>{params.dim , params.sim },
507+ entries_{mr},
508+ shard_vector_locks_ (shard_set_size) {
509+ DCHECK (!params.use_hnsw );
510+ entries_.resize (shard_set_size);
511+ for (size_t i = 0 ; i < shard_set_size; i++) {
512+ entries_[i].resize (params.capacity * params.dim );
513+ }
514+ }
508515
509- // TODO: Let get vector write to buf itself
516+ void FlatVectorIndex::AddVector (GlobalDocId id,
517+ const typename BaseVectorIndex<GlobalDocId>::VectorPtr& vector) {
518+ auto shard_id = search::GlobalDocIdShardId (id);
519+ auto shard_doc_id = search::GlobalDocIdLocalId (id);
520+ if (shard_doc_id * BaseVectorIndex<GlobalDocId>::dim_ == entries_[shard_id].size ()) {
521+ unique_lock<util::fb2::SharedMutex> lock{shard_vector_locks_[shard_id]};
522+ entries_[shard_id].resize ((shard_doc_id + 1 ) * BaseVectorIndex<GlobalDocId>::dim_);
523+ }
510524 if (vector) {
511- memcpy (&entries_[id * BaseVectorIndex<T >::dim_], vector.get (),
512- BaseVectorIndex<T >::dim_ * sizeof (float ));
525+ memcpy (&entries_[shard_id][shard_doc_id * BaseVectorIndex<GlobalDocId >::dim_], vector.get (),
526+ BaseVectorIndex<GlobalDocId >::dim_ * sizeof (float ));
513527 }
514528}
515529
516- template <typename T>
517- void FlatVectorIndex<T>::Remove(T id, const DocumentAccessor& doc, string_view field) {
530+ void FlatVectorIndex::Remove (GlobalDocId id, const DocumentAccessor& doc, string_view field) {
518531 // noop
519532}
520533
521- template <typename T> const float * FlatVectorIndex<T>::Get(T doc) const {
522- return &entries_[doc * dim_];
534+ const float * FlatVectorIndex::Get (GlobalDocId doc) const {
535+ ShardId shard_id = search::GlobalDocIdShardId (doc);
536+ shared_lock<util::fb2::SharedMutex> lock{shard_vector_locks_[shard_id]};
537+ return &entries_[shard_id][search::GlobalDocIdLocalId (doc) * dim_];
538+ }
539+
540+ std::vector<std::pair<float , GlobalDocId>> FlatVectorIndex::Knn (float * target) const {
541+ std::priority_queue<std::pair<float , search::GlobalDocId>> queue;
542+
543+ for (size_t shard_id = 0 ; shard_id < entries_.size (); shard_id++) {
544+ shared_lock<util::fb2::SharedMutex> lock{shard_vector_locks_[shard_id]};
545+ size_t num_vectors = entries_[shard_id].size () / BaseVectorIndex<GlobalDocId>::dim_;
546+ for (GlobalDocId id = 0 ; id < num_vectors; ++id) {
547+ // Check if the vector is not zero (all elements are 0)
548+ // TODO: Valid vector can contain 0s, we should use a better approach
549+ const float * vec = &entries_[shard_id][id * dim_];
550+ float dist = VectorDistance (target, vec, dim_, sim_);
551+ queue.emplace (dist, CreateGlobalDocId (shard_id, id));
552+ }
553+ }
554+
555+ vector<pair<float , search::GlobalDocId>> out (queue.size ());
556+ size_t idx = out.size ();
557+ while (!queue.empty ()) {
558+ out[--idx] = queue.top ();
559+ queue.pop ();
560+ }
561+
562+ return out;
523563}
524564
525- template <typename T> std::vector<T> FlatVectorIndex<T>::GetAllDocsWithNonNullValues() const {
526- std::vector<T> result;
565+ std::vector<std::pair<float , GlobalDocId>> FlatVectorIndex::Knn (
566+ float * target, const std::vector<GlobalDocId>& allowed) const {
567+ std::priority_queue<std::pair<float , search::GlobalDocId>> queue;
568+
569+ for (auto & doc : allowed) {
570+ uint16_t shard_id = search::GlobalDocIdShardId (doc);
571+ auto shard_doc_id = search::GlobalDocIdLocalId (doc);
572+ shared_lock<util::fb2::SharedMutex> lock{shard_vector_locks_[shard_id]};
573+ const float * vec = &entries_[shard_id][shard_doc_id * dim_];
574+ float dist = VectorDistance (target, vec, dim_, sim_);
575+ queue.emplace (dist, doc);
576+ }
577+
578+ vector<pair<float , search::GlobalDocId>> out (queue.size ());
579+ size_t idx = out.size ();
580+ while (!queue.empty ()) {
581+ out[--idx] = queue.top ();
582+ queue.pop ();
583+ }
527584
528- size_t num_vectors = entries_. size () / BaseVectorIndex<T>::dim_ ;
529- result. reserve (num_vectors);
585+ return out ;
586+ }
530587
531- for (T id = 0 ; id < num_vectors; ++id) {
532- // Check if the vector is not zero (all elements are 0)
533- // TODO: Valid vector can contain 0s, we should use a better approach
534- const float * vec = Get (id);
535- bool is_zero_vector = true ;
588+ std::vector<GlobalDocId> FlatVectorIndex::GetAllDocsWithNonNullValues () const {
589+ std::vector<GlobalDocId> result;
590+ for (size_t shard_id = 0 ; shard_id < entries_.size (); shard_id++) {
591+ shared_lock<util::fb2::SharedMutex> lock{shard_vector_locks_[shard_id]};
592+ size_t num_vectors = entries_[shard_id].size () / BaseVectorIndex<GlobalDocId>::dim_;
593+ for (GlobalDocId id = 0 ; id < num_vectors; ++id) {
594+ // Check if the vector is not zero (all elements are 0)
595+ // TODO: Valid vector can contain 0s, we should use a better approach
596+ const float * vec = &entries_[shard_id][id * dim_];
597+ bool is_zero_vector = true ;
536598
537- // TODO: Consider don't use check for zero vector
538- for (size_t i = 0 ; i < BaseVectorIndex<T>::dim_; ++i) {
539- if (vec[i] != 0 .0f ) { // TODO: Consider using a threshold for float comparison
540- is_zero_vector = false ;
541- break ;
599+ // TODO: Consider don't use check for zero vector
600+ for (size_t i = 0 ; i < BaseVectorIndex<GlobalDocId>::dim_; ++i) {
601+ if (vec[i] != 0 .0f ) { // TODO: Consider using a threshold for float comparison
602+ is_zero_vector = false ;
603+ break ;
604+ }
542605 }
543- }
544606
545- if (!is_zero_vector) {
546- result.push_back (id);
607+ if (!is_zero_vector) {
608+ result.push_back (CreateGlobalDocId (shard_id, id));
609+ }
547610 }
548611 }
549612
@@ -552,9 +615,6 @@ template <typename T> std::vector<T> FlatVectorIndex<T>::GetAllDocsWithNonNullVa
552615 return result;
553616}
554617
555- template struct FlatVectorIndex <DocId>;
556- template struct FlatVectorIndex <GlobalDocId>;
557-
558618template <typename T> struct HnswlibAdapter {
559619 // Default setting of hnswlib/hnswalg
560620 constexpr static size_t kDefaultEfRuntime = 10 ;
@@ -662,44 +722,37 @@ template <typename T> struct HnswlibAdapter {
662722 absl::Mutex resize_mutex_;
663723};
664724
665- template <typename T>
666- HnswVectorIndex<T>::HnswVectorIndex(const SchemaField::VectorParams& params,
667- PMR_NS::memory_resource*)
668- : BaseVectorIndex<T>{params.dim , params.sim }, adapter_{make_unique<HnswlibAdapter<T>>(params)} {
725+ HnswVectorIndex::HnswVectorIndex (const SchemaField::VectorParams& params, PMR_NS::memory_resource*)
726+ : BaseVectorIndex<GlobalDocId>{params.dim , params.sim },
727+ adapter_{make_unique<HnswlibAdapter<GlobalDocId>>(params)} {
669728 DCHECK (params.use_hnsw );
670729 // TODO: Patch hnsw to use MR
671730}
672- template < typename T> HnswVectorIndex<T> ::~HnswVectorIndex () {
731+ HnswVectorIndex::~HnswVectorIndex () {
673732}
674733
675- template < typename T>
676- void HnswVectorIndex<T>::AddVector(T id, const typename BaseVectorIndex<T >::VectorPtr& vector) {
734+ void HnswVectorIndex::AddVector (GlobalDocId id,
735+ const typename BaseVectorIndex<GlobalDocId >::VectorPtr& vector) {
677736 if (vector) {
678737 adapter_->Add (vector.get (), id);
679738 }
680739}
681740
682- template <typename T>
683- std::vector<std::pair<float , T>> HnswVectorIndex<T>::Knn(float * target, size_t k,
684- std::optional<size_t > ef) const {
741+ std::vector<std::pair<float , GlobalDocId>> HnswVectorIndex::Knn (float * target, size_t k,
742+ std::optional<size_t > ef) const {
685743 return adapter_->Knn (target, k, ef);
686744}
687745
688- template <typename T>
689- std::vector<std::pair<float , T>> HnswVectorIndex<T>::Knn(float * target, size_t k,
690- std::optional<size_t > ef,
691- const std::vector<T>& allowed) const {
746+ std::vector<std::pair<float , GlobalDocId>> HnswVectorIndex::Knn (
747+ float * target, size_t k, std::optional<size_t > ef,
748+ const std::vector<GlobalDocId>& allowed) const {
692749 return adapter_->Knn (target, k, ef, allowed);
693750}
694751
695- template <typename T>
696- void HnswVectorIndex<T>::Remove(T id, const DocumentAccessor& doc, string_view field) {
752+ void HnswVectorIndex::Remove (GlobalDocId id, const DocumentAccessor& doc, string_view field) {
697753 adapter_->Remove (id);
698754}
699755
700- template struct HnswVectorIndex <DocId>;
701- template struct HnswVectorIndex <GlobalDocId>;
702-
703756GeoIndex::GeoIndex (PMR_NS::memory_resource* mr) : rtree_(make_unique<rtree>()) {
704757}
705758
0 commit comments