Skip to content

Commit 669270b

Browse files
committed
Modify FlatVectorIndex to work with global vector index
1 parent 63f4b49 commit 669270b

File tree

2 files changed

+172
-104
lines changed

2 files changed

+172
-104
lines changed

src/core/search/indices.cc

Lines changed: 117 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,11 @@
1111
#include <absl/strings/str_split.h>
1212

1313
#include <boost/iterator/function_output_iterator.hpp>
14+
#include <shared_mutex>
15+
16+
#include "core/search/base.h"
17+
#include "core/search/vector_utils.h"
18+
#include "util/fibers/synchronization.h"
1419

1520
#define UNI_ALGO_DISABLE_NFKC_NFKD
1621

@@ -128,6 +133,16 @@ std::optional<GeoIndex::point> GetGeoPoint(const DocumentAccessor& doc, string_v
128133
return GeoIndex::point{lon, lat};
129134
}
130135

136+
template <typename Q, typename T = GlobalDocId> vector<pair<float, T>> QueueToVec(Q queue) {
137+
vector<pair<float, T>> out(queue.size());
138+
size_t idx = out.size();
139+
while (!queue.empty()) {
140+
out[--idx] = queue.top();
141+
queue.pop();
142+
}
143+
return out;
144+
}
145+
131146
}; // namespace
132147

133148
class RangeTreeAdapter : public NumericIndex::RangeTreeBase {
@@ -492,69 +507,73 @@ bool BaseVectorIndex<T>::Add(T id, const DocumentAccessor& doc, std::string_view
492507
return true;
493508
}
494509

495-
template <typename T>
496-
FlatVectorIndex<T>::FlatVectorIndex(const SchemaField::VectorParams& params,
497-
PMR_NS::memory_resource* mr)
498-
: BaseVectorIndex<T>{params.dim, params.sim}, entries_{mr} {
499-
DCHECK(!params.use_hnsw);
500-
entries_.reserve(params.capacity * params.dim);
510+
ShardNoOpVectorIndex::ShardNoOpVectorIndex(const SchemaField::VectorParams& params)
511+
: BaseVectorIndex<DocId>{params.dim, params.sim} {
501512
}
502513

503-
template <typename T>
504-
void FlatVectorIndex<T>::AddVector(T id, const typename BaseVectorIndex<T>::VectorPtr& vector) {
505-
DCHECK_LE(id * BaseVectorIndex<T>::dim_, entries_.size());
506-
if (id * BaseVectorIndex<T>::dim_ == entries_.size())
507-
entries_.resize((id + 1) * BaseVectorIndex<T>::dim_);
514+
FlatVectorIndex::FlatVectorIndex(const SchemaField::VectorParams& params, ShardId shard_set_size,
515+
PMR_NS::memory_resource* mr)
516+
: BaseVectorIndex<GlobalDocId>{params.dim, params.sim},
517+
entries_{mr},
518+
shard_vector_locks_(shard_set_size) {
519+
DCHECK(!params.use_hnsw);
520+
entries_.resize(shard_set_size);
521+
for (size_t i = 0; i < shard_set_size; i++) {
522+
entries_[i].reserve(params.capacity * params.dim);
523+
}
524+
}
508525

509-
// TODO: Let get vector write to buf itself
526+
void FlatVectorIndex::AddVector(GlobalDocId id,
527+
const typename BaseVectorIndex<GlobalDocId>::VectorPtr& vector) {
528+
auto shard_id = search::GlobalDocIdShardId(id);
529+
auto shard_doc_id = search::GlobalDocIdLocalId(id);
530+
DCHECK_LE(shard_doc_id * BaseVectorIndex<GlobalDocId>::dim_, entries_[shard_id].size());
531+
if (shard_doc_id * BaseVectorIndex<GlobalDocId>::dim_ == entries_[shard_id].size()) {
532+
unique_lock<util::fb2::SharedMutex> lock{shard_vector_locks_[shard_id]};
533+
entries_[shard_id].resize((shard_doc_id + 1) * BaseVectorIndex<GlobalDocId>::dim_);
534+
}
510535
if (vector) {
511-
memcpy(&entries_[id * BaseVectorIndex<T>::dim_], vector.get(),
512-
BaseVectorIndex<T>::dim_ * sizeof(float));
536+
memcpy(&entries_[shard_id][shard_doc_id * BaseVectorIndex<GlobalDocId>::dim_], vector.get(),
537+
BaseVectorIndex<GlobalDocId>::dim_ * sizeof(float));
513538
}
514539
}
515540

516-
template <typename T>
517-
void FlatVectorIndex<T>::Remove(T id, const DocumentAccessor& doc, string_view field) {
541+
void FlatVectorIndex::Remove(GlobalDocId id, const DocumentAccessor& doc, string_view field) {
518542
// noop
519543
}
520544

521-
template <typename T> const float* FlatVectorIndex<T>::Get(T doc) const {
522-
return &entries_[doc * dim_];
523-
}
545+
std::vector<std::pair<float, GlobalDocId>> FlatVectorIndex::Knn(float* target, size_t k) const {
546+
std::priority_queue<std::pair<float, search::GlobalDocId>> queue;
524547

525-
template <typename T> std::vector<T> FlatVectorIndex<T>::GetAllDocsWithNonNullValues() const {
526-
std::vector<DocId> result;
548+
for (size_t shard_id = 0; shard_id < entries_.size(); shard_id++) {
549+
shared_lock<util::fb2::SharedMutex> lock{shard_vector_locks_[shard_id]};
550+
size_t num_vectors = entries_[shard_id].size() / BaseVectorIndex<GlobalDocId>::dim_;
551+
for (GlobalDocId id = 0; id < num_vectors; ++id) {
552+
const float* vec = &entries_[shard_id][id * dim_];
553+
float dist = VectorDistance(target, vec, dim_, sim_);
554+
queue.emplace(dist, CreateGlobalDocId(shard_id, id));
555+
}
556+
}
527557

528-
size_t num_vectors = entries_.size() / BaseVectorIndex<T>::dim_;
529-
result.reserve(num_vectors);
558+
return QueueToVec(queue);
559+
}
530560

531-
for (DocId id = 0; id < num_vectors; ++id) {
532-
// Check if the vector is not zero (all elements are 0)
533-
// TODO: Valid vector can contain 0s, we should use a better approach
534-
const float* vec = Get(id);
535-
bool is_zero_vector = true;
561+
std::vector<std::pair<float, GlobalDocId>> FlatVectorIndex::Knn(
562+
float* target, size_t k, const std::vector<FilterShardDocs>& allowed_docs) const {
563+
std::priority_queue<std::pair<float, search::GlobalDocId>> queue;
536564

537-
// TODO: Consider don't use check for zero vector
538-
for (size_t i = 0; i < BaseVectorIndex<T>::dim_; ++i) {
539-
if (vec[i] != 0.0f) { // TODO: Consider using a threshold for float comparison
540-
is_zero_vector = false;
541-
break;
542-
}
543-
}
544-
545-
if (!is_zero_vector) {
546-
result.push_back(id);
565+
for (size_t shard_id = 0; shard_id < allowed_docs.size(); shard_id++) {
566+
shared_lock<util::fb2::SharedMutex> lock{shard_vector_locks_[shard_id]};
567+
for (auto& shard_doc_id : allowed_docs[shard_id]) {
568+
const float* vec = &entries_[shard_id][shard_doc_id * dim_];
569+
float dist = VectorDistance(target, vec, dim_, sim_);
570+
queue.emplace(dist, CreateGlobalDocId(shard_id, shard_doc_id));
547571
}
548572
}
549-
550-
// Result is already sorted by id, no need to sort again
551-
// Also it has no duplicates
552-
return result;
573+
return QueueToVec(queue);
553574
}
554575

555-
template struct FlatVectorIndex<DocId>;
556-
557-
struct HnswlibAdapter {
576+
template <typename T> struct HnswlibAdapter {
558577
// Default setting of hnswlib/hnswalg
559578
constexpr static size_t kDefaultEfRuntime = 10;
560579

@@ -564,34 +583,45 @@ struct HnswlibAdapter {
564583
100 /* seed*/} {
565584
}
566585

567-
void Add(const float* data, DocId id) {
568-
if (world_.cur_element_count + 1 >= world_.max_elements_)
569-
world_.resizeIndex(world_.cur_element_count * 2);
570-
world_.addPoint(data, id);
586+
void Add(const float* data, T id) {
587+
while (true) {
588+
try {
589+
absl::ReaderMutexLock lock(&resize_mutex_);
590+
world_.addPoint(data, id);
591+
return;
592+
} catch (const std::exception& e) {
593+
std::string error_msg = e.what();
594+
if (absl::StrContains(error_msg, "The number of elements exceeds the specified limit")) {
595+
ResizeIfFull();
596+
continue;
597+
}
598+
throw e;
599+
}
600+
}
571601
}
572602

573-
void Remove(DocId id) {
603+
void Remove(T id) {
574604
try {
575605
world_.markDelete(id);
576606
} catch (const std::exception& e) {
577607
}
578608
}
579609

580-
vector<pair<float, DocId>> Knn(float* target, size_t k, std::optional<size_t> ef) {
610+
vector<pair<float, T>> Knn(float* target, size_t k, std::optional<size_t> ef) {
581611
world_.setEf(ef.value_or(kDefaultEfRuntime));
582612
return QueueToVec(world_.searchKnn(target, k));
583613
}
584614

585-
vector<pair<float, DocId>> Knn(float* target, size_t k, std::optional<size_t> ef,
586-
const vector<DocId>& allowed) {
615+
vector<pair<float, T>> Knn(float* target, size_t k, std::optional<size_t> ef,
616+
const vector<T>& allowed) {
587617
struct BinsearchFilter : hnswlib::BaseFilterFunctor {
588618
virtual bool operator()(hnswlib::labeltype id) {
589619
return binary_search(allowed->begin(), allowed->end(), id);
590620
}
591621

592-
BinsearchFilter(const vector<DocId>* allowed) : allowed{allowed} {
622+
BinsearchFilter(const vector<T>* allowed) : allowed{allowed} {
593623
}
594-
const vector<DocId>* allowed;
624+
const vector<T>* allowed;
595625
};
596626

597627
world_.setEf(ef.value_or(kDefaultEfRuntime));
@@ -613,57 +643,64 @@ struct HnswlibAdapter {
613643
return visit([](auto& space) -> hnswlib::SpaceInterface<float>* { return &space; }, space_);
614644
}
615645

616-
template <typename Q> static vector<pair<float, DocId>> QueueToVec(Q queue) {
617-
vector<pair<float, DocId>> out(queue.size());
618-
size_t idx = out.size();
619-
while (!queue.empty()) {
620-
out[--idx] = queue.top();
621-
queue.pop();
646+
void ResizeIfFull() {
647+
{
648+
absl::ReaderMutexLock lock(&resize_mutex_);
649+
if (world_.getCurrentElementCount() < world_.getMaxElements() ||
650+
(world_.allow_replace_deleted_ && world_.getDeletedCount() > 0)) {
651+
return;
652+
}
653+
}
654+
try {
655+
absl::WriterMutexLock lock(&resize_mutex_);
656+
if (world_.getCurrentElementCount() == world_.getMaxElements() &&
657+
(!world_.allow_replace_deleted_ || world_.getDeletedCount() == 0)) {
658+
auto max_elements = world_.getMaxElements();
659+
world_.resizeIndex(max_elements * 2);
660+
LOG(INFO) << "Resizing HNSW Index, current size: " << max_elements
661+
<< ", expand by: " << max_elements * 2;
662+
}
663+
} catch (const std::exception& e) {
664+
throw e;
622665
}
623-
return out;
624666
}
625667

626668
SpaceUnion space_;
627669
hnswlib::HierarchicalNSW<float> world_;
670+
absl::Mutex resize_mutex_;
628671
};
629672

630-
template <typename T>
631-
HnswVectorIndex<T>::HnswVectorIndex(const SchemaField::VectorParams& params,
632-
PMR_NS::memory_resource*)
633-
: BaseVectorIndex<T>{params.dim, params.sim}, adapter_{make_unique<HnswlibAdapter>(params)} {
673+
HnswVectorIndex::HnswVectorIndex(const SchemaField::VectorParams& params, PMR_NS::memory_resource*)
674+
: BaseVectorIndex<GlobalDocId>{params.dim, params.sim},
675+
adapter_{make_unique<HnswlibAdapter<GlobalDocId>>(params)} {
634676
DCHECK(params.use_hnsw);
635677
// TODO: Patch hnsw to use MR
636678
}
637-
template <typename T> HnswVectorIndex<T>::~HnswVectorIndex() {
679+
HnswVectorIndex::~HnswVectorIndex() {
638680
}
639681

640-
template <typename T>
641-
void HnswVectorIndex<T>::AddVector(T id, const typename BaseVectorIndex<T>::VectorPtr& vector) {
682+
void HnswVectorIndex::AddVector(GlobalDocId id,
683+
const typename BaseVectorIndex<GlobalDocId>::VectorPtr& vector) {
642684
if (vector) {
643685
adapter_->Add(vector.get(), id);
644686
}
645687
}
646688

647-
template <typename T>
648-
std::vector<std::pair<float, T>> HnswVectorIndex<T>::Knn(float* target, size_t k,
649-
std::optional<size_t> ef) const {
689+
std::vector<std::pair<float, GlobalDocId>> HnswVectorIndex::Knn(float* target, size_t k,
690+
std::optional<size_t> ef) const {
650691
return adapter_->Knn(target, k, ef);
651692
}
652693

653-
template <typename T>
654-
std::vector<std::pair<float, T>> HnswVectorIndex<T>::Knn(float* target, size_t k,
655-
std::optional<size_t> ef,
656-
const std::vector<T>& allowed) const {
694+
std::vector<std::pair<float, GlobalDocId>> HnswVectorIndex::Knn(
695+
float* target, size_t k, std::optional<size_t> ef,
696+
const std::vector<GlobalDocId>& allowed) const {
657697
return adapter_->Knn(target, k, ef, allowed);
658698
}
659699

660-
template <typename T>
661-
void HnswVectorIndex<T>::Remove(T id, const DocumentAccessor& doc, string_view field) {
700+
void HnswVectorIndex::Remove(GlobalDocId id, const DocumentAccessor& doc, string_view field) {
662701
adapter_->Remove(id);
663702
}
664703

665-
template struct HnswVectorIndex<DocId>;
666-
667704
GeoIndex::GeoIndex(PMR_NS::memory_resource* mr) : rtree_(make_unique<rtree>()) {
668705
}
669706

0 commit comments

Comments
 (0)