Skip to content

Commit 6480e41

Browse files
committed
Templated BaseIndex class
1 parent 642f2ea commit 6480e41

File tree

7 files changed

+100
-80
lines changed

7 files changed

+100
-80
lines changed

src/core/search/base.h

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -79,16 +79,16 @@ struct DocumentAccessor {
7979
//
8080
// Queries should be done directly on subclasses with their distinc
8181
// query functions. All results for all index types should be sorted.
82-
struct BaseIndex {
82+
template <typename T> struct BaseIndex {
8383
virtual ~BaseIndex() = default;
8484

8585
// Returns true if the document was added / indexed
86-
virtual bool Add(DocId id, const DocumentAccessor& doc, std::string_view field) = 0;
87-
virtual void Remove(DocId id, const DocumentAccessor& doc, std::string_view field) = 0;
86+
virtual bool Add(T id, const DocumentAccessor& doc, std::string_view field) = 0;
87+
virtual void Remove(T id, const DocumentAccessor& doc, std::string_view field) = 0;
8888

8989
// Returns documents that have non-null values for this field (used for @field:* queries)
9090
// Result must be sorted
91-
virtual std::vector<DocId> GetAllDocsWithNonNullValues() const = 0;
91+
virtual std::vector<T> GetAllDocsWithNonNullValues() const = 0;
9292

9393
/* Called at the end of indexes rebuilding after all initial Add calls are done.
9494
Some indices may need to finalize internal structures. See RangeTree for example. */
@@ -97,10 +97,9 @@ struct BaseIndex {
9797
};
9898

9999
// Base class for type-specific sorting indices.
100-
struct BaseSortIndex : BaseIndex {
101-
virtual SortableValue Lookup(DocId doc) const = 0;
102-
virtual std::vector<SortableValue> Sort(std::vector<DocId>* ids, size_t limit,
103-
bool desc) const = 0;
100+
template <typename T> struct BaseSortIndex : BaseIndex<T> {
101+
virtual SortableValue Lookup(T doc) const = 0;
102+
virtual std::vector<SortableValue> Sort(std::vector<T>* ids, size_t limit, bool desc) const = 0;
104103
};
105104

106105
/* Used in iterators of inverse indices.

src/core/search/indices.cc

Lines changed: 42 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -473,14 +473,12 @@ absl::flat_hash_set<std::string> TagIndex::Tokenize(std::string_view value) cons
473473
return NormalizeTags(value, case_sensitive_, separator_);
474474
}
475475

476-
BaseVectorIndex::BaseVectorIndex(size_t dim, VectorSimilarity sim) : dim_{dim}, sim_{sim} {
476+
template <typename T>
477+
BaseVectorIndex<T>::BaseVectorIndex(size_t dim, VectorSimilarity sim) : dim_{dim}, sim_{sim} {
477478
}
478479

479-
std::pair<size_t /*dim*/, VectorSimilarity> BaseVectorIndex::Info() const {
480-
return {dim_, sim_};
481-
}
482-
483-
bool BaseVectorIndex::Add(DocId id, const DocumentAccessor& doc, std::string_view field) {
480+
template <typename T>
481+
bool BaseVectorIndex<T>::Add(T id, const DocumentAccessor& doc, std::string_view field) {
484482
auto vector = doc.GetVector(field);
485483
if (!vector)
486484
return false;
@@ -494,36 +492,40 @@ bool BaseVectorIndex::Add(DocId id, const DocumentAccessor& doc, std::string_vie
494492
return true;
495493
}
496494

497-
FlatVectorIndex::FlatVectorIndex(const SchemaField::VectorParams& params,
498-
PMR_NS::memory_resource* mr)
499-
: BaseVectorIndex{params.dim, params.sim}, entries_{mr} {
495+
template <typename T>
496+
FlatVectorIndex<T>::FlatVectorIndex(const SchemaField::VectorParams& params,
497+
PMR_NS::memory_resource* mr)
498+
: BaseVectorIndex<T>{params.dim, params.sim}, entries_{mr} {
500499
DCHECK(!params.use_hnsw);
501500
entries_.reserve(params.capacity * params.dim);
502501
}
503502

504-
void FlatVectorIndex::AddVector(DocId id, const VectorPtr& vector) {
505-
DCHECK_LE(id * dim_, entries_.size());
506-
if (id * dim_ == entries_.size())
507-
entries_.resize((id + 1) * dim_);
503+
template <typename T>
504+
void FlatVectorIndex<T>::AddVector(T id, const typename BaseVectorIndex<T>::VectorPtr& vector) {
505+
DCHECK_LE(id * BaseVectorIndex<T>::dim_, entries_.size());
506+
if (id * BaseVectorIndex<T>::dim_ == entries_.size())
507+
entries_.resize((id + 1) * BaseVectorIndex<T>::dim_);
508508

509509
// TODO: Let get vector write to buf itself
510510
if (vector) {
511-
memcpy(&entries_[id * dim_], vector.get(), dim_ * sizeof(float));
511+
memcpy(&entries_[id * BaseVectorIndex<T>::dim_], vector.get(),
512+
BaseVectorIndex<T>::dim_ * sizeof(float));
512513
}
513514
}
514515

515-
void FlatVectorIndex::Remove(DocId id, const DocumentAccessor& doc, string_view field) {
516+
template <typename T>
517+
void FlatVectorIndex<T>::Remove(T id, const DocumentAccessor& doc, string_view field) {
516518
// noop
517519
}
518520

519-
const float* FlatVectorIndex::Get(DocId doc) const {
521+
template <typename T> const float* FlatVectorIndex<T>::Get(T doc) const {
520522
return &entries_[doc * dim_];
521523
}
522524

523-
std::vector<DocId> FlatVectorIndex::GetAllDocsWithNonNullValues() const {
525+
template <typename T> std::vector<T> FlatVectorIndex<T>::GetAllDocsWithNonNullValues() const {
524526
std::vector<DocId> result;
525527

526-
size_t num_vectors = entries_.size() / dim_;
528+
size_t num_vectors = entries_.size() / BaseVectorIndex<T>::dim_;
527529
result.reserve(num_vectors);
528530

529531
for (DocId id = 0; id < num_vectors; ++id) {
@@ -533,7 +535,7 @@ std::vector<DocId> FlatVectorIndex::GetAllDocsWithNonNullValues() const {
533535
bool is_zero_vector = true;
534536

535537
// TODO: Consider don't use check for zero vector
536-
for (size_t i = 0; i < dim_; ++i) {
538+
for (size_t i = 0; i < BaseVectorIndex<T>::dim_; ++i) {
537539
if (vec[i] != 0.0f) { // TODO: Consider using a threshold for float comparison
538540
is_zero_vector = false;
539541
break;
@@ -550,6 +552,8 @@ std::vector<DocId> FlatVectorIndex::GetAllDocsWithNonNullValues() const {
550552
return result;
551553
}
552554

555+
template struct FlatVectorIndex<DocId>;
556+
553557
struct HnswlibAdapter {
554558
// Default setting of hnswlib/hnswalg
555559
constexpr static size_t kDefaultEfRuntime = 10;
@@ -623,35 +627,43 @@ struct HnswlibAdapter {
623627
hnswlib::HierarchicalNSW<float> world_;
624628
};
625629

626-
HnswVectorIndex::HnswVectorIndex(const SchemaField::VectorParams& params, PMR_NS::memory_resource*)
627-
: BaseVectorIndex{params.dim, params.sim}, adapter_{make_unique<HnswlibAdapter>(params)} {
630+
template <typename T>
631+
HnswVectorIndex<T>::HnswVectorIndex(const SchemaField::VectorParams& params,
632+
PMR_NS::memory_resource*)
633+
: BaseVectorIndex<T>{params.dim, params.sim}, adapter_{make_unique<HnswlibAdapter>(params)} {
628634
DCHECK(params.use_hnsw);
629635
// TODO: Patch hnsw to use MR
630636
}
631-
632-
HnswVectorIndex::~HnswVectorIndex() {
637+
template <typename T> HnswVectorIndex<T>::~HnswVectorIndex() {
633638
}
634639

635-
void HnswVectorIndex::AddVector(DocId id, const VectorPtr& vector) {
640+
template <typename T>
641+
void HnswVectorIndex<T>::AddVector(T id, const typename BaseVectorIndex<T>::VectorPtr& vector) {
636642
if (vector) {
637643
adapter_->Add(vector.get(), id);
638644
}
639645
}
640646

641-
std::vector<std::pair<float, DocId>> HnswVectorIndex::Knn(float* target, size_t k,
642-
std::optional<size_t> ef) const {
647+
template <typename T>
648+
std::vector<std::pair<float, T>> HnswVectorIndex<T>::Knn(float* target, size_t k,
649+
std::optional<size_t> ef) const {
643650
return adapter_->Knn(target, k, ef);
644651
}
645-
std::vector<std::pair<float, DocId>> HnswVectorIndex::Knn(float* target, size_t k,
646-
std::optional<size_t> ef,
647-
const std::vector<DocId>& allowed) const {
652+
653+
template <typename T>
654+
std::vector<std::pair<float, T>> HnswVectorIndex<T>::Knn(float* target, size_t k,
655+
std::optional<size_t> ef,
656+
const std::vector<T>& allowed) const {
648657
return adapter_->Knn(target, k, ef, allowed);
649658
}
650659

651-
void HnswVectorIndex::Remove(DocId id, const DocumentAccessor& doc, string_view field) {
660+
template <typename T>
661+
void HnswVectorIndex<T>::Remove(T id, const DocumentAccessor& doc, string_view field) {
652662
adapter_->Remove(id);
653663
}
654664

665+
template struct HnswVectorIndex<DocId>;
666+
655667
GeoIndex::GeoIndex(PMR_NS::memory_resource* mr) : rtree_(make_unique<rtree>()) {
656668
}
657669

src/core/search/indices.h

Lines changed: 27 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ namespace dfly::search {
3939

4040
// Index for integer fields.
4141
// Range bounds are queried in logarithmic time, iteration is constant.
42-
struct NumericIndex : public BaseIndex {
42+
struct NumericIndex : public BaseIndex<DocId> {
4343
// Temporary base class for range tree.
4444
// It is used to use two different range trees depending on the flag use_range_tree.
4545
// If the flag is true, RangeTree is used, otherwise a simple implementation with btree_set.
@@ -76,7 +76,7 @@ struct NumericIndex : public BaseIndex {
7676
};
7777

7878
// Base index for string based indices.
79-
template <typename C> struct BaseStringIndex : public BaseIndex {
79+
template <typename C> struct BaseStringIndex : public BaseIndex<DocId> {
8080
using Container = BlockList<C>;
8181
using VecOrPtr = std::variant<std::vector<DocId>, const Container*>;
8282

@@ -157,65 +157,72 @@ struct TagIndex : public BaseStringIndex<SortedVector<DocId>> {
157157
char separator_;
158158
};
159159

160-
struct BaseVectorIndex : public BaseIndex {
161-
std::pair<size_t /*dim*/, VectorSimilarity> Info() const;
160+
template <typename T> struct BaseVectorIndex : public BaseIndex<T> {
161+
std::pair<size_t /*dim*/, VectorSimilarity> Info() const {
162+
return {dim_, sim_};
163+
}
162164

163-
bool Add(DocId id, const DocumentAccessor& doc, std::string_view field) override final;
165+
bool Add(T id, const DocumentAccessor& doc, std::string_view field) override final;
164166

165167
protected:
166168
BaseVectorIndex(size_t dim, VectorSimilarity sim);
167169

168170
using VectorPtr = decltype(std::declval<OwnedFtVector>().first);
169-
virtual void AddVector(DocId id, const VectorPtr& vector) = 0;
171+
virtual void AddVector(T id, const VectorPtr& vector) = 0;
170172

171173
size_t dim_;
172174
VectorSimilarity sim_;
173175
};
174176

175177
// Index for vector fields.
176178
// Only supports lookup by id.
177-
struct FlatVectorIndex : public BaseVectorIndex {
179+
template <typename T> struct FlatVectorIndex : public BaseVectorIndex<T> {
178180
FlatVectorIndex(const SchemaField::VectorParams& params, PMR_NS::memory_resource* mr);
179181

180-
void Remove(DocId id, const DocumentAccessor& doc, std::string_view field) override;
182+
void Remove(T id, const DocumentAccessor& doc, std::string_view field) override;
181183

182-
const float* Get(DocId doc) const;
184+
const float* Get(T doc) const;
183185

184186
// Return all documents that have vectors in this index
185-
std::vector<DocId> GetAllDocsWithNonNullValues() const override;
187+
std::vector<T> GetAllDocsWithNonNullValues() const override;
186188

187189
protected:
188-
void AddVector(DocId id, const VectorPtr& vector) override;
190+
using BaseVectorIndex<T>::dim_;
191+
void AddVector(T id, const typename BaseVectorIndex<T>::VectorPtr& vector) override;
189192

190193
private:
191194
PMR_NS::vector<float> entries_;
192195
};
193196

197+
extern template struct FlatVectorIndex<DocId>;
198+
194199
struct HnswlibAdapter;
195200

196-
struct HnswVectorIndex : public BaseVectorIndex {
201+
template <typename T> struct HnswVectorIndex : public BaseVectorIndex<T> {
197202
HnswVectorIndex(const SchemaField::VectorParams& params, PMR_NS::memory_resource* mr);
198203
~HnswVectorIndex();
199204

200-
void Remove(DocId id, const DocumentAccessor& doc, std::string_view field) override;
205+
void Remove(T id, const DocumentAccessor& doc, std::string_view field) override;
201206

202-
std::vector<std::pair<float, DocId>> Knn(float* target, size_t k, std::optional<size_t> ef) const;
203-
std::vector<std::pair<float, DocId>> Knn(float* target, size_t k, std::optional<size_t> ef,
204-
const std::vector<DocId>& allowed) const;
207+
std::vector<std::pair<float, T>> Knn(float* target, size_t k, std::optional<size_t> ef) const;
208+
std::vector<std::pair<float, T>> Knn(float* target, size_t k, std::optional<size_t> ef,
209+
const std::vector<T>& allowed) const;
205210

206211
// TODO: Implement if needed
207-
std::vector<DocId> GetAllDocsWithNonNullValues() const override {
208-
return std::vector<DocId>{};
212+
std::vector<T> GetAllDocsWithNonNullValues() const override {
213+
return std::vector<T>{};
209214
}
210215

211216
protected:
212-
void AddVector(DocId id, const VectorPtr& vector) override;
217+
void AddVector(T id, const typename BaseVectorIndex<T>::VectorPtr& vector) override;
213218

214219
private:
215220
std::unique_ptr<HnswlibAdapter> adapter_;
216221
};
217222

218-
struct GeoIndex : public BaseIndex {
223+
extern template struct HnswVectorIndex<DocId>;
224+
225+
struct GeoIndex : public BaseIndex<DocId> {
219226
using point =
220227
boost::geometry::model::point<double, 2,
221228
boost::geometry::cs::geographic<boost::geometry::degree>>;

0 commit comments

Comments
 (0)