Skip to content

Commit 63f4b49

Browse files
committed
Global Vector Index
1 parent 6480e41 commit 63f4b49

File tree

12 files changed

+612
-133
lines changed

12 files changed

+612
-133
lines changed

src/core/search/ast_expr.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,10 @@ AstKnnNode::AstKnnNode(AstNode&& filter, AstKnnNode&& self) {
7373
this->filter = make_unique<AstNode>(std::move(filter));
7474
}
7575

76+
bool AstKnnNode::Filter() const {
77+
return filter == nullptr;
78+
}
79+
7680
} // namespace dfly::search
7781

7882
namespace std {

src/core/search/ast_expr.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,8 @@ struct AstKnnNode {
114114
OwnedFtVector vec;
115115
std::string score_alias;
116116
std::optional<float> ef_runtime;
117+
118+
bool Filter() const;
117119
};
118120

119121
using NodeVariants =

src/core/search/base.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,24 @@
1616
#include "absl/container/flat_hash_set.h"
1717
#include "base/pmr/memory_resource.h"
1818
#include "core/string_map.h"
19+
#include "server/tx_base.h"
1920

2021
namespace dfly::search {
2122

2223
using DocId = uint32_t;
24+
using GlobalDocId = uint64_t;
25+
26+
inline GlobalDocId CreateGlobalDocId(ShardId shard_id, DocId local_doc_id) {
27+
return ((uint64_t)shard_id << 32) | local_doc_id;
28+
}
29+
30+
inline ShardId GlobalDocIdShardId(GlobalDocId id) {
31+
return (id >> 32);
32+
}
33+
34+
inline search::DocId GlobalDocIdLocalId(GlobalDocId id) {
35+
return (id)&0xFFFF;
36+
}
2337

2438
enum class VectorSimilarity { L2, IP, COSINE };
2539

src/core/search/search.cc

Lines changed: 21 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -336,68 +336,10 @@ struct BasicSearch {
336336
return UnifyResults(GetSubResults(node.tags, mapping), LogicOp::OR);
337337
}
338338

339-
void SearchKnnFlat(FlatVectorIndex<DocId>* vec_index, const AstKnnNode& knn,
340-
IndexResult&& sub_results) {
341-
knn_distances_.reserve(sub_results.ApproximateSize());
342-
auto cb = [&](auto* set) {
343-
auto [dim, sim] = vec_index->Info();
344-
for (DocId matched_doc : *set) {
345-
float dist = VectorDistance(knn.vec.first.get(), vec_index->Get(matched_doc), dim, sim);
346-
knn_distances_.emplace_back(dist, matched_doc);
347-
}
348-
};
349-
visit(cb, sub_results.Borrowed());
350-
351-
size_t prefix_size = min(knn.limit, knn_distances_.size());
352-
partial_sort(knn_distances_.begin(), knn_distances_.begin() + prefix_size,
353-
knn_distances_.end());
354-
knn_distances_.resize(prefix_size);
355-
}
356-
357-
void SearchKnnHnsw(HnswVectorIndex<DocId>* vec_index, const AstKnnNode& knn,
358-
IndexResult&& sub_results) {
359-
if (indices_->GetAllDocs().size() == sub_results.ApproximateSize()) // TODO: remove approx size
360-
knn_distances_ = vec_index->Knn(knn.vec.first.get(), knn.limit, knn.ef_runtime);
361-
else
362-
knn_distances_ =
363-
vec_index->Knn(knn.vec.first.get(), knn.limit, knn.ef_runtime, sub_results.Take().first);
364-
}
365-
366339
// [KNN limit @field vec]: Compute distance from `vec` to all vectors keep closest `limit`
367340
IndexResult Search(const AstKnnNode& knn, string_view active_field) {
368-
DCHECK(active_field.empty());
369-
auto sub_results = SearchGeneric(*knn.filter, active_field);
370-
371-
auto* vec_index = GetIndex<BaseVectorIndex<DocId>>(knn.field);
372-
if (!vec_index)
373-
return IndexResult{};
374-
375-
// If vector dimension is 0, treat as placeholder/invalid - return empty results
376-
// This allows tests to use dummy vector values like "<your_vector_blob>"
377-
if (knn.vec.second == 0)
378-
return IndexResult{};
379-
380-
if (auto [dim, _] = vec_index->Info(); dim != knn.vec.second) {
381-
error_ =
382-
absl::StrCat("Wrong vector index dimensions, got: ", knn.vec.second, ", expected: ", dim);
383-
return IndexResult{};
384-
}
385-
386-
knn_scores_.clear();
387-
if (auto hnsw_index = dynamic_cast<HnswVectorIndex<DocId>*>(vec_index); hnsw_index)
388-
SearchKnnHnsw(hnsw_index, knn, std::move(sub_results));
389-
else
390-
SearchKnnFlat(dynamic_cast<FlatVectorIndex<DocId>*>(vec_index), knn, std::move(sub_results));
391-
392-
vector<DocId> out(knn_distances_.size());
393-
knn_scores_.reserve(knn_distances_.size());
394-
395-
for (size_t i = 0; i < knn_distances_.size(); i++) {
396-
knn_scores_.emplace_back(knn_distances_[i].second, knn_distances_[i].first);
397-
out[i] = knn_distances_[i].second;
398-
}
399-
400-
return IndexResult{std::move(out)};
341+
LOG(DFATAL) << "KNN node should not be searched in shard";
342+
return IndexResult{};
401343
}
402344

403345
// Determine node type and call specific search function
@@ -503,24 +445,15 @@ void FieldIndices::CreateIndices(PMR_NS::memory_resource* mr) {
503445
indices_[field_ident] = make_unique<TagIndex>(mr, tparams);
504446
break;
505447
}
506-
case SchemaField::VECTOR: {
507-
unique_ptr<BaseVectorIndex<DocId>> vector_index;
508-
509-
DCHECK(holds_alternative<SchemaField::VectorParams>(field_info.special_params));
510-
const auto& vparams = std::get<SchemaField::VectorParams>(field_info.special_params);
511-
512-
if (vparams.use_hnsw)
513-
vector_index = make_unique<HnswVectorIndex<DocId>>(vparams, mr);
514-
else
515-
vector_index = make_unique<FlatVectorIndex<DocId>>(vparams, mr);
516-
517-
indices_[field_ident] = std::move(vector_index);
518-
break;
519-
}
520448
case SchemaField::GEO: {
521449
indices_[field_ident] = make_unique<GeoIndex>(mr);
522450
break;
523451
}
452+
case SchemaField::VECTOR: {
453+
const auto& vparams = std::get<SchemaField::VectorParams>(field_info.special_params);
454+
indices_[field_ident] = make_unique<ShardNoOpVectorIndex>(vparams);
455+
break;
456+
}
524457
}
525458
}
526459
}
@@ -666,14 +599,21 @@ SearchResult SearchAlgorithm::Search(const FieldIndices* index, size_t cuttoff_l
666599
return bs.Search(*query_, cuttoff_limit);
667600
}
668601

669-
optional<KnnScoreSortOption> SearchAlgorithm::GetKnnScoreSortOption() const {
670-
DCHECK(query_);
671-
672-
// KNN query
673-
if (auto* knn = get_if<AstKnnNode>(query_.get()); knn)
674-
return KnnScoreSortOption{string_view{knn->score_alias}, knn->limit};
602+
bool SearchAlgorithm::IsKnnQuery() const {
603+
return std::holds_alternative<AstKnnNode>(*query_);
604+
}
675605

676-
return nullopt;
606+
std::unique_ptr<AstNode> SearchAlgorithm::GetKnnNode() {
607+
if (auto* knn = get_if<AstKnnNode>(query_.get()); knn) {
608+
// Save knn score sort option
609+
knn_score_sort_option_ = KnnScoreSortOption{string_view{knn->score_alias}, knn->limit};
610+
auto node = std::move(query_);
611+
if (!std::holds_alternative<AstStarNode>(*(knn)->filter))
612+
query_.swap(knn->filter);
613+
return node;
614+
}
615+
LOG(DFATAL) << "Should not reach here";
616+
return nullptr;
677617
}
678618

679619
void SearchAlgorithm::EnableProfiling() {

src/core/search/search.h

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ namespace dfly::search {
2323

2424
struct AstNode;
2525
struct TextIndex;
26+
struct AstKnnNode;
2627

2728
// Optional FILTER
2829
struct OptionalNumericFilter : public OptionalFilterBase {
@@ -201,14 +202,20 @@ class SearchAlgorithm {
201202
SearchResult Search(const FieldIndices* index,
202203
size_t cuttoff_limit = std::numeric_limits<size_t>::max()) const;
203204

204-
// if enabled, return limit & alias for knn query
205-
std::optional<KnnScoreSortOption> GetKnnScoreSortOption() const;
205+
const std::optional<KnnScoreSortOption>& GetKnnScoreSortOption() const {
206+
return knn_score_sort_option_;
207+
}
208+
209+
bool IsKnnQuery() const;
210+
211+
std::unique_ptr<AstNode> GetKnnNode();
206212

207213
void EnableProfiling();
208214

209215
private:
210216
bool profiling_enabled_ = false;
211217
std::unique_ptr<AstNode> query_;
218+
std::optional<KnnScoreSortOption> knn_score_sort_option_;
212219
};
213220

214221
} // namespace dfly::search

src/server/main_service.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ extern "C" {
5959
#include "server/multi_command_squasher.h"
6060
#include "server/namespaces.h"
6161
#include "server/script_mgr.h"
62+
#include "server/search/global_vector_index.h"
6263
#include "server/search/search_family.h"
6364
#include "server/server_state.h"
6465
#include "server/set_family.h"
@@ -1123,6 +1124,11 @@ void Service::Shutdown() {
11231124

11241125
shard_set->PreShutdown();
11251126
shard_set->Shutdown();
1127+
1128+
#ifdef WITH_SEARCH
1129+
GlobalVectorIndexRegistry::Instance().Reset();
1130+
#endif
1131+
11261132
Transaction::Shutdown();
11271133

11281134
pp_.AwaitFiberOnAll([](ProactorBase* pb) { ServerState::tlocal()->Destroy(); });

src/server/search/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ if (NOT WITH_SEARCH)
44
return()
55
endif()
66

7-
add_library(dfly_search_server aggregator.cc doc_accessors.cc doc_index.cc search_family.cc index_join.cc
7+
add_library(dfly_search_server aggregator.cc doc_accessors.cc doc_index.cc search_family.cc index_join.cc global_vector_index.cc
88
../cluster/coordinator.cc)
99
target_link_libraries(dfly_search_server dfly_transaction dragonfly_lib dfly_facade redis_lib jsonpath TRDP::jsoncons)
1010

0 commit comments

Comments
 (0)