Skip to content

Commit e65709a

Browse files
committed
WIP
1 parent 3a54c72 commit e65709a

File tree

8 files changed

+373
-312
lines changed

8 files changed

+373
-312
lines changed

src/core/search/indices.cc

Lines changed: 110 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,11 @@
1111
#include <absl/strings/str_split.h>
1212

1313
#include <boost/iterator/function_output_iterator.hpp>
14+
#include <shared_mutex>
15+
16+
#include "core/search/base.h"
17+
#include "core/search/vector_utils.h"
18+
#include "util/fibers/synchronization.h"
1419

1520
#define UNI_ALGO_DISABLE_NFKC_NFKD
1621

@@ -492,58 +497,116 @@ bool BaseVectorIndex<T>::Add(T id, const DocumentAccessor& doc, std::string_view
492497
return true;
493498
}
494499

495-
template <typename T>
496-
FlatVectorIndex<T>::FlatVectorIndex(const SchemaField::VectorParams& params,
497-
PMR_NS::memory_resource* mr)
498-
: BaseVectorIndex<T>{params.dim, params.sim}, entries_{mr} {
499-
DCHECK(!params.use_hnsw);
500-
entries_.reserve(params.capacity * params.dim);
500+
ShardNoOpVectorIndex::ShardNoOpVectorIndex(const SchemaField::VectorParams& params)
501+
: BaseVectorIndex<DocId>{params.dim, params.sim} {
501502
}
502503

503-
template <typename T>
504-
void FlatVectorIndex<T>::AddVector(T id, const typename BaseVectorIndex<T>::VectorPtr& vector) {
505-
DCHECK_LE(id * BaseVectorIndex<T>::dim_, entries_.size());
506-
if (id * BaseVectorIndex<T>::dim_ == entries_.size())
507-
entries_.resize((id + 1) * BaseVectorIndex<T>::dim_);
504+
FlatVectorIndex::FlatVectorIndex(const SchemaField::VectorParams& params, ShardId shard_set_size,
505+
PMR_NS::memory_resource* mr)
506+
: BaseVectorIndex<GlobalDocId>{params.dim, params.sim},
507+
entries_{mr},
508+
shard_vector_locks_(shard_set_size) {
509+
DCHECK(!params.use_hnsw);
510+
entries_.resize(shard_set_size);
511+
for (size_t i = 0; i < shard_set_size; i++) {
512+
entries_[i].resize(params.capacity * params.dim);
513+
}
514+
}
508515

509-
// TODO: Let get vector write to buf itself
516+
void FlatVectorIndex::AddVector(GlobalDocId id,
517+
const typename BaseVectorIndex<GlobalDocId>::VectorPtr& vector) {
518+
auto shard_id = search::GlobalDocIdShardId(id);
519+
auto shard_doc_id = search::GlobalDocIdLocalId(id);
520+
if (shard_doc_id * BaseVectorIndex<GlobalDocId>::dim_ == entries_[shard_id].size()) {
521+
unique_lock<util::fb2::SharedMutex> lock{shard_vector_locks_[shard_id]};
522+
entries_[shard_id].resize((shard_doc_id + 1) * BaseVectorIndex<GlobalDocId>::dim_);
523+
}
510524
if (vector) {
511-
memcpy(&entries_[id * BaseVectorIndex<T>::dim_], vector.get(),
512-
BaseVectorIndex<T>::dim_ * sizeof(float));
525+
memcpy(&entries_[shard_id][shard_doc_id * BaseVectorIndex<GlobalDocId>::dim_], vector.get(),
526+
BaseVectorIndex<GlobalDocId>::dim_ * sizeof(float));
513527
}
514528
}
515529

516-
template <typename T>
517-
void FlatVectorIndex<T>::Remove(T id, const DocumentAccessor& doc, string_view field) {
530+
void FlatVectorIndex::Remove(GlobalDocId id, const DocumentAccessor& doc, string_view field) {
518531
// noop
519532
}
520533

521-
template <typename T> const float* FlatVectorIndex<T>::Get(T doc) const {
522-
return &entries_[doc * dim_];
534+
const float* FlatVectorIndex::Get(GlobalDocId doc) const {
535+
ShardId shard_id = search::GlobalDocIdShardId(doc);
536+
shared_lock<util::fb2::SharedMutex> lock{shard_vector_locks_[shard_id]};
537+
return &entries_[shard_id][search::GlobalDocIdLocalId(doc) * dim_];
538+
}
539+
540+
std::vector<std::pair<float, GlobalDocId>> FlatVectorIndex::Knn(float* target) const {
541+
std::priority_queue<std::pair<float, search::GlobalDocId>> queue;
542+
543+
for (size_t shard_id = 0; shard_id < entries_.size(); shard_id++) {
544+
shared_lock<util::fb2::SharedMutex> lock{shard_vector_locks_[shard_id]};
545+
size_t num_vectors = entries_[shard_id].size() / BaseVectorIndex<GlobalDocId>::dim_;
546+
for (GlobalDocId id = 0; id < num_vectors; ++id) {
547+
// NOTE: no zero-vector check is performed here; every stored slot is scored.
548+
// TODO: Valid vector can contain 0s, we should use a better approach
549+
const float* vec = &entries_[shard_id][id * dim_];
550+
float dist = VectorDistance(target, vec, dim_, sim_);
551+
queue.emplace(dist, CreateGlobalDocId(shard_id, id));
552+
}
553+
}
554+
555+
vector<pair<float, search::GlobalDocId>> out(queue.size());
556+
size_t idx = out.size();
557+
while (!queue.empty()) {
558+
out[--idx] = queue.top();
559+
queue.pop();
560+
}
561+
562+
return out;
523563
}
524564

525-
template <typename T> std::vector<T> FlatVectorIndex<T>::GetAllDocsWithNonNullValues() const {
526-
std::vector<T> result;
565+
std::vector<std::pair<float, GlobalDocId>> FlatVectorIndex::Knn(
566+
float* target, const std::vector<GlobalDocId>& allowed) const {
567+
std::priority_queue<std::pair<float, search::GlobalDocId>> queue;
568+
569+
for (auto& doc : allowed) {
570+
uint16_t shard_id = search::GlobalDocIdShardId(doc);
571+
auto shard_doc_id = search::GlobalDocIdLocalId(doc);
572+
shared_lock<util::fb2::SharedMutex> lock{shard_vector_locks_[shard_id]};
573+
const float* vec = &entries_[shard_id][shard_doc_id * dim_];
574+
float dist = VectorDistance(target, vec, dim_, sim_);
575+
queue.emplace(dist, doc);
576+
}
577+
578+
vector<pair<float, search::GlobalDocId>> out(queue.size());
579+
size_t idx = out.size();
580+
while (!queue.empty()) {
581+
out[--idx] = queue.top();
582+
queue.pop();
583+
}
527584

528-
size_t num_vectors = entries_.size() / BaseVectorIndex<T>::dim_;
529-
result.reserve(num_vectors);
585+
return out;
586+
}
530587

531-
for (T id = 0; id < num_vectors; ++id) {
532-
// Check if the vector is not zero (all elements are 0)
533-
// TODO: Valid vector can contain 0s, we should use a better approach
534-
const float* vec = Get(id);
535-
bool is_zero_vector = true;
588+
std::vector<GlobalDocId> FlatVectorIndex::GetAllDocsWithNonNullValues() const {
589+
std::vector<GlobalDocId> result;
590+
for (size_t shard_id = 0; shard_id < entries_.size(); shard_id++) {
591+
shared_lock<util::fb2::SharedMutex> lock{shard_vector_locks_[shard_id]};
592+
size_t num_vectors = entries_[shard_id].size() / BaseVectorIndex<GlobalDocId>::dim_;
593+
for (GlobalDocId id = 0; id < num_vectors; ++id) {
594+
// Skip zero vectors (all elements are 0); only non-zero vectors are reported.
595+
// TODO: A valid vector can consist entirely of 0s; we should use a better approach.
596+
const float* vec = &entries_[shard_id][id * dim_];
597+
bool is_zero_vector = true;
536598

537-
// TODO: Consider don't use check for zero vector
538-
for (size_t i = 0; i < BaseVectorIndex<T>::dim_; ++i) {
539-
if (vec[i] != 0.0f) { // TODO: Consider using a threshold for float comparison
540-
is_zero_vector = false;
541-
break;
599+
// TODO: Consider not using the zero-vector check
600+
for (size_t i = 0; i < BaseVectorIndex<GlobalDocId>::dim_; ++i) {
601+
if (vec[i] != 0.0f) { // TODO: Consider using a threshold for float comparison
602+
is_zero_vector = false;
603+
break;
604+
}
542605
}
543-
}
544606

545-
if (!is_zero_vector) {
546-
result.push_back(id);
607+
if (!is_zero_vector) {
608+
result.push_back(CreateGlobalDocId(shard_id, id));
609+
}
547610
}
548611
}
549612

@@ -552,9 +615,6 @@ template <typename T> std::vector<T> FlatVectorIndex<T>::GetAllDocsWithNonNullVa
552615
return result;
553616
}
554617

555-
template struct FlatVectorIndex<DocId>;
556-
template struct FlatVectorIndex<GlobalDocId>;
557-
558618
template <typename T> struct HnswlibAdapter {
559619
// Default setting of hnswlib/hnswalg
560620
constexpr static size_t kDefaultEfRuntime = 10;
@@ -662,44 +722,37 @@ template <typename T> struct HnswlibAdapter {
662722
absl::Mutex resize_mutex_;
663723
};
664724

665-
template <typename T>
666-
HnswVectorIndex<T>::HnswVectorIndex(const SchemaField::VectorParams& params,
667-
PMR_NS::memory_resource*)
668-
: BaseVectorIndex<T>{params.dim, params.sim}, adapter_{make_unique<HnswlibAdapter<T>>(params)} {
725+
HnswVectorIndex::HnswVectorIndex(const SchemaField::VectorParams& params, PMR_NS::memory_resource*)
726+
: BaseVectorIndex<GlobalDocId>{params.dim, params.sim},
727+
adapter_{make_unique<HnswlibAdapter<GlobalDocId>>(params)} {
669728
DCHECK(params.use_hnsw);
670729
// TODO: Patch hnsw to use MR
671730
}
672-
template <typename T> HnswVectorIndex<T>::~HnswVectorIndex() {
731+
HnswVectorIndex::~HnswVectorIndex() {
673732
}
674733

675-
template <typename T>
676-
void HnswVectorIndex<T>::AddVector(T id, const typename BaseVectorIndex<T>::VectorPtr& vector) {
734+
void HnswVectorIndex::AddVector(GlobalDocId id,
735+
const typename BaseVectorIndex<GlobalDocId>::VectorPtr& vector) {
677736
if (vector) {
678737
adapter_->Add(vector.get(), id);
679738
}
680739
}
681740

682-
template <typename T>
683-
std::vector<std::pair<float, T>> HnswVectorIndex<T>::Knn(float* target, size_t k,
684-
std::optional<size_t> ef) const {
741+
std::vector<std::pair<float, GlobalDocId>> HnswVectorIndex::Knn(float* target, size_t k,
742+
std::optional<size_t> ef) const {
685743
return adapter_->Knn(target, k, ef);
686744
}
687745

688-
template <typename T>
689-
std::vector<std::pair<float, T>> HnswVectorIndex<T>::Knn(float* target, size_t k,
690-
std::optional<size_t> ef,
691-
const std::vector<T>& allowed) const {
746+
std::vector<std::pair<float, GlobalDocId>> HnswVectorIndex::Knn(
747+
float* target, size_t k, std::optional<size_t> ef,
748+
const std::vector<GlobalDocId>& allowed) const {
692749
return adapter_->Knn(target, k, ef, allowed);
693750
}
694751

695-
template <typename T>
696-
void HnswVectorIndex<T>::Remove(T id, const DocumentAccessor& doc, string_view field) {
752+
void HnswVectorIndex::Remove(GlobalDocId id, const DocumentAccessor& doc, string_view field) {
697753
adapter_->Remove(id);
698754
}
699755

700-
template struct HnswVectorIndex<DocId>;
701-
template struct HnswVectorIndex<GlobalDocId>;
702-
703756
GeoIndex::GeoIndex(PMR_NS::memory_resource* mr) : rtree_(make_unique<rtree>()) {
704757
}
705758

src/core/search/indices.h

Lines changed: 51 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88
#include <absl/container/flat_hash_map.h>
99
#include <absl/container/flat_hash_set.h>
1010

11+
//#include "server/search/global_vector_index.h"
12+
#include "util/fibers/synchronization.h"
13+
1114
// Wrong warning reported when geometry.hpp is loaded
1215
#ifndef __clang__
1316
#pragma GCC diagnostic push
@@ -174,56 +177,80 @@ template <typename T> struct BaseVectorIndex : public BaseIndex<T> {
174177
VectorSimilarity sim_;
175178
};
176179

180+
// ShardNoOpVectorIndex is a placeholder for the vector index in each shard. It doesn't implement
181+
// any functionality, so adding documents has no effect on it. It exists only so that it can
182+
// act as a filter when fields are added.
183+
struct ShardNoOpVectorIndex : public BaseVectorIndex<DocId> {
184+
explicit ShardNoOpVectorIndex(const SchemaField::VectorParams& params);
185+
186+
void Remove(DocId id, const DocumentAccessor& doc, std::string_view field) override {
187+
// noop
188+
}
189+
190+
// Return all documents that have vectors in this index
191+
std::vector<DocId> GetAllDocsWithNonNullValues() const override {
192+
return {};
193+
}
194+
195+
protected:
196+
using BaseVectorIndex<DocId>::dim_;
197+
void AddVector(DocId id, const typename BaseVectorIndex<DocId>::VectorPtr& vector) override {
198+
// noop
199+
}
200+
};
201+
177202
// Index for vector fields.
178203
// Only supports lookup by id.
179-
template <typename T> struct FlatVectorIndex : public BaseVectorIndex<T> {
180-
FlatVectorIndex(const SchemaField::VectorParams& params, PMR_NS::memory_resource* mr);
204+
struct FlatVectorIndex : public BaseVectorIndex<GlobalDocId> {
205+
FlatVectorIndex(const SchemaField::VectorParams& params, ShardId shard_set_size,
206+
PMR_NS::memory_resource* mr);
207+
208+
void Remove(GlobalDocId id, const DocumentAccessor& doc, std::string_view field) override;
181209

182-
void Remove(T id, const DocumentAccessor& doc, std::string_view field) override;
210+
const float* Get(GlobalDocId doc) const;
183211

184-
const float* Get(T doc) const;
212+
std::vector<std::pair<float, GlobalDocId>> Knn(float* target) const;
213+
std::vector<std::pair<float, GlobalDocId>> Knn(float* target,
214+
const std::vector<GlobalDocId>& allowed) const;
185215

186216
// Return all documents that have vectors in this index
187-
std::vector<T> GetAllDocsWithNonNullValues() const override;
217+
std::vector<GlobalDocId> GetAllDocsWithNonNullValues() const override;
188218

189219
protected:
190-
using BaseVectorIndex<T>::dim_;
191-
void AddVector(T id, const typename BaseVectorIndex<T>::VectorPtr& vector) override;
220+
using BaseVectorIndex<GlobalDocId>::dim_;
221+
void AddVector(GlobalDocId id,
222+
const typename BaseVectorIndex<GlobalDocId>::VectorPtr& vector) override;
192223

193224
private:
194-
PMR_NS::vector<float> entries_;
225+
PMR_NS::vector<PMR_NS::vector<float>> entries_;
226+
mutable std::vector<util::fb2::SharedMutex> shard_vector_locks_;
195227
};
196228

197-
extern template struct FlatVectorIndex<DocId>;
198-
extern template struct FlatVectorIndex<GlobalDocId>;
199-
200229
template <typename T> struct HnswlibAdapter;
201-
202-
template <typename T> struct HnswVectorIndex : public BaseVectorIndex<T> {
230+
struct HnswVectorIndex : public BaseVectorIndex<GlobalDocId> {
203231
HnswVectorIndex(const SchemaField::VectorParams& params, PMR_NS::memory_resource* mr);
204232
~HnswVectorIndex();
205233

206-
void Remove(T id, const DocumentAccessor& doc, std::string_view field) override;
234+
void Remove(GlobalDocId id, const DocumentAccessor& doc, std::string_view field) override;
207235

208-
std::vector<std::pair<float, T>> Knn(float* target, size_t k, std::optional<size_t> ef) const;
209-
std::vector<std::pair<float, T>> Knn(float* target, size_t k, std::optional<size_t> ef,
210-
const std::vector<T>& allowed) const;
236+
std::vector<std::pair<float, GlobalDocId>> Knn(float* target, size_t k,
237+
std::optional<size_t> ef) const;
238+
std::vector<std::pair<float, GlobalDocId>> Knn(float* target, size_t k, std::optional<size_t> ef,
239+
const std::vector<GlobalDocId>& allowed) const;
211240

212241
// TODO: Implement if needed
213-
std::vector<T> GetAllDocsWithNonNullValues() const override {
214-
return std::vector<T>{};
242+
std::vector<GlobalDocId> GetAllDocsWithNonNullValues() const override {
243+
return std::vector<GlobalDocId>{};
215244
}
216245

217246
protected:
218-
void AddVector(T id, const typename BaseVectorIndex<T>::VectorPtr& vector) override;
247+
void AddVector(GlobalDocId id,
248+
const typename BaseVectorIndex<GlobalDocId>::VectorPtr& vector) override;
219249

220250
private:
221-
std::unique_ptr<HnswlibAdapter<T>> adapter_;
251+
std::unique_ptr<HnswlibAdapter<GlobalDocId>> adapter_;
222252
};
223253

224-
extern template struct HnswVectorIndex<DocId>;
225-
extern template struct HnswVectorIndex<GlobalDocId>;
226-
227254
struct GeoIndex : public BaseIndex<DocId> {
228255
using point =
229256
boost::geometry::model::point<double, 2,

0 commit comments

Comments
 (0)