Skip to content

Commit 3a54c72

Browse files
committed
Global vector index
* Disable creation of vector index in shard * Execution flow with/without knn node * Add/Remove/Rebuild docs * Use GlobalDocId for vector index
1 parent 6480e41 commit 3a54c72

File tree

13 files changed

+558
-69
lines changed

13 files changed

+558
-69
lines changed

src/core/search/ast_expr.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,10 @@ AstKnnNode::AstKnnNode(AstNode&& filter, AstKnnNode&& self) {
7373
this->filter = make_unique<AstNode>(std::move(filter));
7474
}
7575

76+
bool AstKnnNode::HasFilterNode() const {
77+
return filter == nullptr;
78+
}
79+
7680
} // namespace dfly::search
7781

7882
namespace std {

src/core/search/ast_expr.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,8 @@ struct AstKnnNode {
114114
OwnedFtVector vec;
115115
std::string score_alias;
116116
std::optional<float> ef_runtime;
117+
118+
bool HasFilterNode() const;
117119
};
118120

119121
using NodeVariants =

src/core/search/base.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,24 @@
1616
#include "absl/container/flat_hash_set.h"
1717
#include "base/pmr/memory_resource.h"
1818
#include "core/string_map.h"
19+
#include "server/tx_base.h"
1920

2021
namespace dfly::search {
2122

2223
using DocId = uint32_t;
24+
using GlobalDocId = uint64_t;
25+
26+
inline GlobalDocId CreateGlobalDocId(ShardId shard_id, DocId local_doc_id) {
27+
return ((uint64_t)shard_id << 32) | local_doc_id;
28+
}
29+
30+
inline ShardId GlobalDocIdShardId(GlobalDocId id) {
31+
return (id >> 32);
32+
}
33+
34+
inline search::DocId GlobalDocIdLocalId(GlobalDocId id) {
35+
return (id)&0xFFFF;
36+
}
2337

2438
enum class VectorSimilarity { L2, IP, COSINE };
2539

src/core/search/indices.cc

Lines changed: 52 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -523,12 +523,12 @@ template <typename T> const float* FlatVectorIndex<T>::Get(T doc) const {
523523
}
524524

525525
template <typename T> std::vector<T> FlatVectorIndex<T>::GetAllDocsWithNonNullValues() const {
526-
std::vector<DocId> result;
526+
std::vector<T> result;
527527

528528
size_t num_vectors = entries_.size() / BaseVectorIndex<T>::dim_;
529529
result.reserve(num_vectors);
530530

531-
for (DocId id = 0; id < num_vectors; ++id) {
531+
for (T id = 0; id < num_vectors; ++id) {
532532
// Check if the vector is not zero (all elements are 0)
533533
// TODO: Valid vector can contain 0s, we should use a better approach
534534
const float* vec = Get(id);
@@ -553,8 +553,9 @@ template <typename T> std::vector<T> FlatVectorIndex<T>::GetAllDocsWithNonNullVa
553553
}
554554

555555
template struct FlatVectorIndex<DocId>;
556+
template struct FlatVectorIndex<GlobalDocId>;
556557

557-
struct HnswlibAdapter {
558+
template <typename T> struct HnswlibAdapter {
558559
// Default setting of hnswlib/hnswalg
559560
constexpr static size_t kDefaultEfRuntime = 10;
560561

@@ -564,34 +565,45 @@ struct HnswlibAdapter {
564565
100 /* seed*/} {
565566
}
566567

567-
void Add(const float* data, DocId id) {
568-
if (world_.cur_element_count + 1 >= world_.max_elements_)
569-
world_.resizeIndex(world_.cur_element_count * 2);
570-
world_.addPoint(data, id);
568+
void Add(const float* data, T id) {
569+
while (true) {
570+
try {
571+
absl::ReaderMutexLock lock(&resize_mutex_);
572+
world_.addPoint(data, id);
573+
return;
574+
} catch (const std::exception& e) {
575+
std::string error_msg = e.what();
576+
if (absl::StrContains(error_msg, "The number of elements exceeds the specified limit")) {
577+
ResizeIfFull();
578+
continue;
579+
}
580+
throw e;
581+
}
582+
}
571583
}
572584

573-
void Remove(DocId id) {
585+
void Remove(T id) {
574586
try {
575587
world_.markDelete(id);
576588
} catch (const std::exception& e) {
577589
}
578590
}
579591

580-
vector<pair<float, DocId>> Knn(float* target, size_t k, std::optional<size_t> ef) {
592+
vector<pair<float, T>> Knn(float* target, size_t k, std::optional<size_t> ef) {
581593
world_.setEf(ef.value_or(kDefaultEfRuntime));
582594
return QueueToVec(world_.searchKnn(target, k));
583595
}
584596

585-
vector<pair<float, DocId>> Knn(float* target, size_t k, std::optional<size_t> ef,
586-
const vector<DocId>& allowed) {
597+
vector<pair<float, T>> Knn(float* target, size_t k, std::optional<size_t> ef,
598+
const vector<T>& allowed) {
587599
struct BinsearchFilter : hnswlib::BaseFilterFunctor {
588600
virtual bool operator()(hnswlib::labeltype id) {
589601
return binary_search(allowed->begin(), allowed->end(), id);
590602
}
591603

592-
BinsearchFilter(const vector<DocId>* allowed) : allowed{allowed} {
604+
BinsearchFilter(const vector<T>* allowed) : allowed{allowed} {
593605
}
594-
const vector<DocId>* allowed;
606+
const vector<T>* allowed;
595607
};
596608

597609
world_.setEf(ef.value_or(kDefaultEfRuntime));
@@ -613,8 +625,8 @@ struct HnswlibAdapter {
613625
return visit([](auto& space) -> hnswlib::SpaceInterface<float>* { return &space; }, space_);
614626
}
615627

616-
template <typename Q> static vector<pair<float, DocId>> QueueToVec(Q queue) {
617-
vector<pair<float, DocId>> out(queue.size());
628+
template <typename Q> static vector<pair<float, T>> QueueToVec(Q queue) {
629+
vector<pair<float, T>> out(queue.size());
618630
size_t idx = out.size();
619631
while (!queue.empty()) {
620632
out[--idx] = queue.top();
@@ -623,14 +635,37 @@ struct HnswlibAdapter {
623635
return out;
624636
}
625637

638+
void ResizeIfFull() {
639+
{
640+
absl::ReaderMutexLock lock(&resize_mutex_);
641+
if (world_.getCurrentElementCount() < world_.getMaxElements() ||
642+
(world_.allow_replace_deleted_ && world_.getDeletedCount() > 0)) {
643+
return;
644+
}
645+
}
646+
try {
647+
absl::WriterMutexLock lock(&resize_mutex_);
648+
if (world_.getCurrentElementCount() == world_.getMaxElements() &&
649+
(!world_.allow_replace_deleted_ || world_.getDeletedCount() == 0)) {
650+
auto max_elements = world_.getMaxElements();
651+
world_.resizeIndex(max_elements * 2);
652+
LOG(INFO) << "Resizing HNSW Index, current size: " << max_elements
653+
<< ", expand by: " << max_elements * 2;
654+
}
655+
} catch (const std::exception& e) {
656+
throw e;
657+
}
658+
}
659+
626660
SpaceUnion space_;
627661
hnswlib::HierarchicalNSW<float> world_;
662+
absl::Mutex resize_mutex_;
628663
};
629664

630665
template <typename T>
631666
HnswVectorIndex<T>::HnswVectorIndex(const SchemaField::VectorParams& params,
632667
PMR_NS::memory_resource*)
633-
: BaseVectorIndex<T>{params.dim, params.sim}, adapter_{make_unique<HnswlibAdapter>(params)} {
668+
: BaseVectorIndex<T>{params.dim, params.sim}, adapter_{make_unique<HnswlibAdapter<T>>(params)} {
634669
DCHECK(params.use_hnsw);
635670
// TODO: Patch hnsw to use MR
636671
}
@@ -663,6 +698,7 @@ void HnswVectorIndex<T>::Remove(T id, const DocumentAccessor& doc, string_view f
663698
}
664699

665700
template struct HnswVectorIndex<DocId>;
701+
template struct HnswVectorIndex<GlobalDocId>;
666702

667703
GeoIndex::GeoIndex(PMR_NS::memory_resource* mr) : rtree_(make_unique<rtree>()) {
668704
}

src/core/search/indices.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -195,8 +195,9 @@ template <typename T> struct FlatVectorIndex : public BaseVectorIndex<T> {
195195
};
196196

197197
extern template struct FlatVectorIndex<DocId>;
198+
extern template struct FlatVectorIndex<GlobalDocId>;
198199

199-
struct HnswlibAdapter;
200+
template <typename T> struct HnswlibAdapter;
200201

201202
template <typename T> struct HnswVectorIndex : public BaseVectorIndex<T> {
202203
HnswVectorIndex(const SchemaField::VectorParams& params, PMR_NS::memory_resource* mr);
@@ -217,10 +218,11 @@ template <typename T> struct HnswVectorIndex : public BaseVectorIndex<T> {
217218
void AddVector(T id, const typename BaseVectorIndex<T>::VectorPtr& vector) override;
218219

219220
private:
220-
std::unique_ptr<HnswlibAdapter> adapter_;
221+
std::unique_ptr<HnswlibAdapter<T>> adapter_;
221222
};
222223

223224
extern template struct HnswVectorIndex<DocId>;
225+
extern template struct HnswVectorIndex<GlobalDocId>;
224226

225227
struct GeoIndex : public BaseIndex<DocId> {
226228
using point =

src/core/search/search.cc

Lines changed: 17 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -503,24 +503,13 @@ void FieldIndices::CreateIndices(PMR_NS::memory_resource* mr) {
503503
indices_[field_ident] = make_unique<TagIndex>(mr, tparams);
504504
break;
505505
}
506-
case SchemaField::VECTOR: {
507-
unique_ptr<BaseVectorIndex<DocId>> vector_index;
508-
509-
DCHECK(holds_alternative<SchemaField::VectorParams>(field_info.special_params));
510-
const auto& vparams = std::get<SchemaField::VectorParams>(field_info.special_params);
511-
512-
if (vparams.use_hnsw)
513-
vector_index = make_unique<HnswVectorIndex<DocId>>(vparams, mr);
514-
else
515-
vector_index = make_unique<FlatVectorIndex<DocId>>(vparams, mr);
516-
517-
indices_[field_ident] = std::move(vector_index);
518-
break;
519-
}
520506
case SchemaField::GEO: {
521507
indices_[field_ident] = make_unique<GeoIndex>(mr);
522508
break;
523509
}
510+
// We have global vector index
511+
case SchemaField::VECTOR:
512+
break;
524513
}
525514
}
526515
}
@@ -666,14 +655,21 @@ SearchResult SearchAlgorithm::Search(const FieldIndices* index, size_t cuttoff_l
666655
return bs.Search(*query_, cuttoff_limit);
667656
}
668657

669-
optional<KnnScoreSortOption> SearchAlgorithm::GetKnnScoreSortOption() const {
670-
DCHECK(query_);
671-
672-
// KNN query
673-
if (auto* knn = get_if<AstKnnNode>(query_.get()); knn)
674-
return KnnScoreSortOption{string_view{knn->score_alias}, knn->limit};
658+
bool SearchAlgorithm::IsKnnQuery() const {
659+
return std::holds_alternative<AstKnnNode>(*query_);
660+
}
675661

676-
return nullopt;
662+
std::unique_ptr<AstNode> SearchAlgorithm::GetKnnNode() {
663+
if (auto* knn = get_if<AstKnnNode>(query_.get()); knn) {
664+
// Save knn score sort option
665+
knn_score_sort_option_ = KnnScoreSortOption{string_view{knn->score_alias}, knn->limit};
666+
auto node = std::move(query_);
667+
if (!std::holds_alternative<AstStarNode>(*(knn)->filter))
668+
query_.swap(knn->filter);
669+
return node;
670+
}
671+
LOG(DFATAL) << "Should not reach here";
672+
return nullptr;
677673
}
678674

679675
void SearchAlgorithm::EnableProfiling() {

src/core/search/search.h

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ namespace dfly::search {
2323

2424
struct AstNode;
2525
struct TextIndex;
26+
struct AstKnnNode;
2627

2728
// Optional FILTER
2829
struct OptionalNumericFilter : public OptionalFilterBase {
@@ -201,14 +202,20 @@ class SearchAlgorithm {
201202
SearchResult Search(const FieldIndices* index,
202203
size_t cuttoff_limit = std::numeric_limits<size_t>::max()) const;
203204

204-
// if enabled, return limit & alias for knn query
205-
std::optional<KnnScoreSortOption> GetKnnScoreSortOption() const;
205+
std::optional<KnnScoreSortOption> GetKnnScoreSortOption() const {
206+
return knn_score_sort_option_;
207+
}
208+
209+
bool IsKnnQuery() const;
210+
211+
std::unique_ptr<AstNode> GetKnnNode();
206212

207213
void EnableProfiling();
208214

209215
private:
210216
bool profiling_enabled_ = false;
211217
std::unique_ptr<AstNode> query_;
218+
std::optional<KnnScoreSortOption> knn_score_sort_option_;
212219
};
213220

214221
} // namespace dfly::search

src/server/search/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ if (NOT WITH_SEARCH)
44
return()
55
endif()
66

7-
add_library(dfly_search_server aggregator.cc doc_accessors.cc doc_index.cc search_family.cc index_join.cc
7+
add_library(dfly_search_server aggregator.cc doc_accessors.cc doc_index.cc search_family.cc index_join.cc global_vector_index.cc
88
../cluster/coordinator.cc)
99
target_link_libraries(dfly_search_server dfly_transaction dragonfly_lib dfly_facade redis_lib jsonpath TRDP::jsoncons)
1010

0 commit comments

Comments
 (0)