@@ -473,14 +473,12 @@ absl::flat_hash_set<std::string> TagIndex::Tokenize(std::string_view value) cons
473473 return NormalizeTags (value, case_sensitive_, separator_);
474474}
475475
476- BaseVectorIndex::BaseVectorIndex (size_t dim, VectorSimilarity sim) : dim_{dim}, sim_{sim} {
476+ template <typename T>
477+ BaseVectorIndex<T>::BaseVectorIndex(size_t dim, VectorSimilarity sim) : dim_{dim}, sim_{sim} {
477478}
478479
479- std::pair<size_t /* dim*/ , VectorSimilarity> BaseVectorIndex::Info () const {
480- return {dim_, sim_};
481- }
482-
483- bool BaseVectorIndex::Add (DocId id, const DocumentAccessor& doc, std::string_view field) {
480+ template <typename T>
481+ bool BaseVectorIndex<T>::Add(T id, const DocumentAccessor& doc, std::string_view field) {
484482 auto vector = doc.GetVector (field);
485483 if (!vector)
486484 return false ;
@@ -494,36 +492,40 @@ bool BaseVectorIndex::Add(DocId id, const DocumentAccessor& doc, std::string_vie
494492 return true ;
495493}
496494
497- FlatVectorIndex::FlatVectorIndex (const SchemaField::VectorParams& params,
498- PMR_NS::memory_resource* mr)
499- : BaseVectorIndex{params.dim , params.sim }, entries_{mr} {
495+ template <typename T>
496+ FlatVectorIndex<T>::FlatVectorIndex(const SchemaField::VectorParams& params,
497+ PMR_NS::memory_resource* mr)
498+ : BaseVectorIndex<T>{params.dim , params.sim }, entries_{mr} {
500499 DCHECK (!params.use_hnsw );
501500 entries_.reserve (params.capacity * params.dim );
502501}
503502
504- void FlatVectorIndex::AddVector (DocId id, const VectorPtr& vector) {
505- DCHECK_LE (id * dim_, entries_.size ());
506- if (id * dim_ == entries_.size ())
507- entries_.resize ((id + 1 ) * dim_);
503+ template <typename T>
504+ void FlatVectorIndex<T>::AddVector(T id, const typename BaseVectorIndex<T>::VectorPtr& vector) {
505+ DCHECK_LE (id * BaseVectorIndex<T>::dim_, entries_.size ());
506+ if (id * BaseVectorIndex<T>::dim_ == entries_.size ())
507+ entries_.resize ((id + 1 ) * BaseVectorIndex<T>::dim_);
508508
509509 // TODO: Let get vector write to buf itself
510510 if (vector) {
511- memcpy (&entries_[id * dim_], vector.get (), dim_ * sizeof (float ));
511+ memcpy (&entries_[id * BaseVectorIndex<T>::dim_], vector.get (),
512+ BaseVectorIndex<T>::dim_ * sizeof (float ));
512513 }
513514}
514515
515- void FlatVectorIndex::Remove (DocId id, const DocumentAccessor& doc, string_view field) {
516+ template <typename T>
517+ void FlatVectorIndex<T>::Remove(T id, const DocumentAccessor& doc, string_view field) {
516518 // noop
517519}
518520
519- const float * FlatVectorIndex::Get (DocId doc) const {
521+ template < typename T> const float * FlatVectorIndex<T> ::Get(T doc) const {
520522 return &entries_[doc * dim_];
521523}
522524
523- std::vector<DocId > FlatVectorIndex::GetAllDocsWithNonNullValues () const {
525+ template < typename T> std::vector<T > FlatVectorIndex<T> ::GetAllDocsWithNonNullValues() const {
524526 std::vector<DocId> result;
525527
526- size_t num_vectors = entries_.size () / dim_;
528+ size_t num_vectors = entries_.size () / BaseVectorIndex<T>:: dim_;
527529 result.reserve (num_vectors);
528530
529531 for (DocId id = 0 ; id < num_vectors; ++id) {
@@ -533,7 +535,7 @@ std::vector<DocId> FlatVectorIndex::GetAllDocsWithNonNullValues() const {
533535 bool is_zero_vector = true ;
534536
535537 // TODO: Consider don't use check for zero vector
536- for (size_t i = 0 ; i < dim_; ++i) {
538+ for (size_t i = 0 ; i < BaseVectorIndex<T>:: dim_; ++i) {
537539 if (vec[i] != 0 .0f ) { // TODO: Consider using a threshold for float comparison
538540 is_zero_vector = false ;
539541 break ;
@@ -550,6 +552,8 @@ std::vector<DocId> FlatVectorIndex::GetAllDocsWithNonNullValues() const {
550552 return result;
551553}
552554
555+ template struct FlatVectorIndex <DocId>;
556+
553557struct HnswlibAdapter {
554558 // Default setting of hnswlib/hnswalg
555559 constexpr static size_t kDefaultEfRuntime = 10 ;
@@ -623,35 +627,43 @@ struct HnswlibAdapter {
623627 hnswlib::HierarchicalNSW<float > world_;
624628};
625629
626- HnswVectorIndex::HnswVectorIndex (const SchemaField::VectorParams& params, PMR_NS::memory_resource*)
627- : BaseVectorIndex{params.dim , params.sim }, adapter_{make_unique<HnswlibAdapter>(params)} {
630+ template <typename T>
631+ HnswVectorIndex<T>::HnswVectorIndex(const SchemaField::VectorParams& params,
632+ PMR_NS::memory_resource*)
633+ : BaseVectorIndex<T>{params.dim , params.sim }, adapter_{make_unique<HnswlibAdapter>(params)} {
628634 DCHECK (params.use_hnsw );
629635 // TODO: Patch hnsw to use MR
630636}
631-
632- HnswVectorIndex::~HnswVectorIndex () {
637+ template <typename T> HnswVectorIndex<T>::~HnswVectorIndex () {
633638}
634639
635- void HnswVectorIndex::AddVector (DocId id, const VectorPtr& vector) {
640+ template <typename T>
641+ void HnswVectorIndex<T>::AddVector(T id, const typename BaseVectorIndex<T>::VectorPtr& vector) {
636642 if (vector) {
637643 adapter_->Add (vector.get (), id);
638644 }
639645}
640646
641- std::vector<std::pair<float , DocId>> HnswVectorIndex::Knn (float * target, size_t k,
642- std::optional<size_t > ef) const {
647+ template <typename T>
648+ std::vector<std::pair<float , T>> HnswVectorIndex<T>::Knn(float * target, size_t k,
649+ std::optional<size_t > ef) const {
643650 return adapter_->Knn (target, k, ef);
644651}
645- std::vector<std::pair<float , DocId>> HnswVectorIndex::Knn (float * target, size_t k,
646- std::optional<size_t > ef,
647- const std::vector<DocId>& allowed) const {
652+
653+ template <typename T>
654+ std::vector<std::pair<float , T>> HnswVectorIndex<T>::Knn(float * target, size_t k,
655+ std::optional<size_t > ef,
656+ const std::vector<T>& allowed) const {
648657 return adapter_->Knn (target, k, ef, allowed);
649658}
650659
651- void HnswVectorIndex::Remove (DocId id, const DocumentAccessor& doc, string_view field) {
660+ template <typename T>
661+ void HnswVectorIndex<T>::Remove(T id, const DocumentAccessor& doc, string_view field) {
652662 adapter_->Remove (id);
653663}
654664
665+ template struct HnswVectorIndex <DocId>;
666+
655667GeoIndex::GeoIndex (PMR_NS::memory_resource* mr) : rtree_(make_unique<rtree>()) {
656668}
657669
0 commit comments