diff --git a/src/core/search/ast_expr.cc b/src/core/search/ast_expr.cc index 4fa111c4c3e6..60c6d3334fae 100644 --- a/src/core/search/ast_expr.cc +++ b/src/core/search/ast_expr.cc @@ -73,6 +73,10 @@ AstKnnNode::AstKnnNode(AstNode&& filter, AstKnnNode&& self) { this->filter = make_unique(std::move(filter)); } +bool AstKnnNode::Filter() const { + return filter == nullptr; +} + } // namespace dfly::search namespace std { diff --git a/src/core/search/ast_expr.h b/src/core/search/ast_expr.h index e62551dc0b7a..c305b24cc667 100644 --- a/src/core/search/ast_expr.h +++ b/src/core/search/ast_expr.h @@ -114,6 +114,8 @@ struct AstKnnNode { OwnedFtVector vec; std::string score_alias; std::optional ef_runtime; + + bool Filter() const; }; using NodeVariants = diff --git a/src/core/search/base.h b/src/core/search/base.h index 492b1435c61e..7358c7349e72 100644 --- a/src/core/search/base.h +++ b/src/core/search/base.h @@ -16,10 +16,24 @@ #include "absl/container/flat_hash_set.h" #include "base/pmr/memory_resource.h" #include "core/string_map.h" +#include "server/tx_base.h" namespace dfly::search { using DocId = uint32_t; +using GlobalDocId = uint64_t; + +inline GlobalDocId CreateGlobalDocId(ShardId shard_id, DocId local_doc_id) { + return ((uint64_t)shard_id << 32) | local_doc_id; +} + +inline ShardId GlobalDocIdShardId(GlobalDocId id) { + return (id >> 32); +} + +inline search::DocId GlobalDocIdLocalId(GlobalDocId id) { + return (id)&0xFFFF; +} enum class VectorSimilarity { L2, IP, COSINE }; @@ -79,16 +93,16 @@ struct DocumentAccessor { // // Queries should be done directly on subclasses with their distinc // query functions. All results for all index types should be sorted. -struct BaseIndex { +template struct BaseIndex { virtual ~BaseIndex() = default; // Returns true if the document was added / indexed - virtual bool Add(DocId id, const DocumentAccessor& doc, std::string_view field) = 0; - virtual void Remove(DocId id, const DocumentAccessor& doc, std::string_view field) = 0; + virtual bool Add(T id, const DocumentAccessor& doc, std::string_view field) = 0; + virtual void Remove(T id, const DocumentAccessor& doc, std::string_view field) = 0; // Returns documents that have non-null values for this field (used for @field:* queries) // Result must be sorted - virtual std::vector GetAllDocsWithNonNullValues() const = 0; + virtual std::vector GetAllDocsWithNonNullValues() const = 0; /* Called at the end of indexes rebuilding after all initial Add calls are done. Some indices may need to finalize internal structures. See RangeTree for example. */ @@ -97,10 +111,9 @@ struct BaseIndex { }; // Base class for type-specific sorting indices. -struct BaseSortIndex : BaseIndex { - virtual SortableValue Lookup(DocId doc) const = 0; - virtual std::vector Sort(std::vector* ids, size_t limit, - bool desc) const = 0; +template struct BaseSortIndex : BaseIndex { + virtual SortableValue Lookup(T doc) const = 0; + virtual std::vector Sort(std::vector* ids, size_t limit, bool desc) const = 0; }; /* Used in iterators of inverse indices. diff --git a/src/core/search/indices.cc b/src/core/search/indices.cc index 24fff112b6a5..495ddecd5d00 100644 --- a/src/core/search/indices.cc +++ b/src/core/search/indices.cc @@ -11,6 +11,11 @@ #include #include +#include + +#include "core/search/base.h" +#include "core/search/vector_utils.h" +#include "util/fibers/synchronization.h" #define UNI_ALGO_DISABLE_NFKC_NFKD @@ -128,6 +133,16 @@ std::optional GetGeoPoint(const DocumentAccessor& doc, string_v return GeoIndex::point{lon, lat}; } +template vector> QueueToVec(Q queue) { + vector> out(queue.size()); + size_t idx = out.size(); + while (!queue.empty()) { + out[--idx] = queue.top(); + queue.pop(); + } + return out; +} + }; // namespace class RangeTreeAdapter : public NumericIndex::RangeTreeBase { @@ -473,14 +488,12 @@ absl::flat_hash_set TagIndex::Tokenize(std::string_view value) cons return NormalizeTags(value, case_sensitive_, separator_); } -BaseVectorIndex::BaseVectorIndex(size_t dim, VectorSimilarity sim) : dim_{dim}, sim_{sim} { +template +BaseVectorIndex::BaseVectorIndex(size_t dim, VectorSimilarity sim) : dim_{dim}, sim_{sim} { } -std::pair BaseVectorIndex::Info() const { - return {dim_, sim_}; -} - -bool BaseVectorIndex::Add(DocId id, const DocumentAccessor& doc, std::string_view field) { +template +bool BaseVectorIndex::Add(T id, const DocumentAccessor& doc, std::string_view field) { auto vector = doc.GetVector(field); if (!vector) return false; @@ -494,63 +507,73 @@ bool BaseVectorIndex::Add(DocId id, const DocumentAccessor& doc, std::string_vie return true; } -FlatVectorIndex::FlatVectorIndex(const SchemaField::VectorParams& params, +ShardNoOpVectorIndex::ShardNoOpVectorIndex(const SchemaField::VectorParams& params) + : BaseVectorIndex{params.dim, params.sim} { +} + +FlatVectorIndex::FlatVectorIndex(const SchemaField::VectorParams& params, ShardId shard_set_size, PMR_NS::memory_resource* mr) - : BaseVectorIndex{params.dim, params.sim}, entries_{mr} { + : BaseVectorIndex{params.dim, params.sim}, + entries_{mr}, + shard_vector_locks_(shard_set_size) { DCHECK(!params.use_hnsw); - entries_.reserve(params.capacity * params.dim); + entries_.resize(shard_set_size); + for (size_t i = 0; i < shard_set_size; i++) { + entries_[i].reserve(params.capacity * params.dim); + } } -void FlatVectorIndex::AddVector(DocId id, const VectorPtr& vector) { - DCHECK_LE(id * dim_, entries_.size()); - if (id * dim_ == entries_.size()) - entries_.resize((id + 1) * dim_); - - // TODO: Let get vector write to buf itself +void FlatVectorIndex::AddVector(GlobalDocId id, + const typename BaseVectorIndex::VectorPtr& vector) { + auto shard_id = search::GlobalDocIdShardId(id); + auto shard_doc_id = search::GlobalDocIdLocalId(id); + DCHECK_LE(shard_doc_id * BaseVectorIndex::dim_, entries_[shard_id].size()); + if (shard_doc_id * BaseVectorIndex::dim_ == entries_[shard_id].size()) { + unique_lock lock{shard_vector_locks_[shard_id]}; + entries_[shard_id].resize((shard_doc_id + 1) * BaseVectorIndex::dim_); + } if (vector) { - memcpy(&entries_[id * dim_], vector.get(), dim_ * sizeof(float)); + memcpy(&entries_[shard_id][shard_doc_id * BaseVectorIndex::dim_], vector.get(), + BaseVectorIndex::dim_ * sizeof(float)); } } -void FlatVectorIndex::Remove(DocId id, const DocumentAccessor& doc, string_view field) { +void FlatVectorIndex::Remove(GlobalDocId id, const DocumentAccessor& doc, string_view field) { // noop } -const float* FlatVectorIndex::Get(DocId doc) const { - return &entries_[doc * dim_]; -} - -std::vector FlatVectorIndex::GetAllDocsWithNonNullValues() const { - std::vector result; +std::vector> FlatVectorIndex::Knn(float* target, size_t k) const { + std::priority_queue> queue; - size_t num_vectors = entries_.size() / dim_; - result.reserve(num_vectors); + for (size_t shard_id = 0; shard_id < entries_.size(); shard_id++) { + shared_lock lock{shard_vector_locks_[shard_id]}; + size_t num_vectors = entries_[shard_id].size() / BaseVectorIndex::dim_; + for (GlobalDocId id = 0; id < num_vectors; ++id) { + const float* vec = &entries_[shard_id][id * dim_]; + float dist = VectorDistance(target, vec, dim_, sim_); + queue.emplace(dist, CreateGlobalDocId(shard_id, id)); + } + } - for (DocId id = 0; id < num_vectors; ++id) { - // Check if the vector is not zero (all elements are 0) - // TODO: Valid vector can contain 0s, we should use a better approach - const float* vec = Get(id); - bool is_zero_vector = true; + return QueueToVec(queue); +} - // TODO: Consider don't use check for zero vector - for (size_t i = 0; i < dim_; ++i) { - if (vec[i] != 0.0f) { // TODO: Consider using a threshold for float comparison - is_zero_vector = false; - break; - } - } +std::vector> FlatVectorIndex::Knn( + float* target, size_t k, const std::vector& allowed_docs) const { + std::priority_queue> queue; - if (!is_zero_vector) { - result.push_back(id); + for (size_t shard_id = 0; shard_id < allowed_docs.size(); shard_id++) { + shared_lock lock{shard_vector_locks_[shard_id]}; + for (auto& shard_doc_id : allowed_docs[shard_id]) { + const float* vec = &entries_[shard_id][shard_doc_id * dim_]; + float dist = VectorDistance(target, vec, dim_, sim_); + queue.emplace(dist, CreateGlobalDocId(shard_id, shard_doc_id)); } } - - // Result is already sorted by id, no need to sort again - // Also it has no duplicates - return result; + return QueueToVec(queue); } -struct HnswlibAdapter { +template struct HnswlibAdapter { // Default setting of hnswlib/hnswalg constexpr static size_t kDefaultEfRuntime = 10; @@ -560,34 +583,45 @@ struct HnswlibAdapter { 100 /* seed*/} { } - void Add(const float* data, DocId id) { - if (world_.cur_element_count + 1 >= world_.max_elements_) - world_.resizeIndex(world_.cur_element_count * 2); - world_.addPoint(data, id); + void Add(const float* data, T id) { + while (true) { + try { + absl::ReaderMutexLock lock(&resize_mutex_); + world_.addPoint(data, id); + return; + } catch (const std::exception& e) { + std::string error_msg = e.what(); + if (absl::StrContains(error_msg, "The number of elements exceeds the specified limit")) { + ResizeIfFull(); + continue; + } + throw e; + } + } } - void Remove(DocId id) { + void Remove(T id) { try { world_.markDelete(id); } catch (const std::exception& e) { } } - vector> Knn(float* target, size_t k, std::optional ef) { + vector> Knn(float* target, size_t k, std::optional ef) { world_.setEf(ef.value_or(kDefaultEfRuntime)); return QueueToVec(world_.searchKnn(target, k)); } - vector> Knn(float* target, size_t k, std::optional ef, - const vector& allowed) { + vector> Knn(float* target, size_t k, std::optional ef, + const vector& allowed) { struct BinsearchFilter : hnswlib::BaseFilterFunctor { virtual bool operator()(hnswlib::labeltype id) { return binary_search(allowed->begin(), allowed->end(), id); } - BinsearchFilter(const vector* allowed) : allowed{allowed} { + BinsearchFilter(const vector* allowed) : allowed{allowed} { } - const vector* allowed; + const vector* allowed; }; world_.setEf(ef.value_or(kDefaultEfRuntime)); @@ -609,46 +643,61 @@ struct HnswlibAdapter { return visit([](auto& space) -> hnswlib::SpaceInterface* { return &space; }, space_); } - template static vector> QueueToVec(Q queue) { - vector> out(queue.size()); - size_t idx = out.size(); - while (!queue.empty()) { - out[--idx] = queue.top(); - queue.pop(); + void ResizeIfFull() { + { + absl::ReaderMutexLock lock(&resize_mutex_); + if (world_.getCurrentElementCount() < world_.getMaxElements() || + (world_.allow_replace_deleted_ && world_.getDeletedCount() > 0)) { + return; + } + } + try { + absl::WriterMutexLock lock(&resize_mutex_); + if (world_.getCurrentElementCount() == world_.getMaxElements() && + (!world_.allow_replace_deleted_ || world_.getDeletedCount() == 0)) { + auto max_elements = world_.getMaxElements(); + world_.resizeIndex(max_elements * 2); + LOG(INFO) << "Resizing HNSW Index, current size: " << max_elements + << ", expand by: " << max_elements * 2; + } + } catch (const std::exception& e) { + throw e; } - return out; } SpaceUnion space_; hnswlib::HierarchicalNSW world_; + absl::Mutex resize_mutex_; }; HnswVectorIndex::HnswVectorIndex(const SchemaField::VectorParams& params, PMR_NS::memory_resource*) - : BaseVectorIndex{params.dim, params.sim}, adapter_{make_unique(params)} { + : BaseVectorIndex{params.dim, params.sim}, + adapter_{make_unique>(params)} { DCHECK(params.use_hnsw); // TODO: Patch hnsw to use MR } - HnswVectorIndex::~HnswVectorIndex() { } -void HnswVectorIndex::AddVector(DocId id, const VectorPtr& vector) { +void HnswVectorIndex::AddVector(GlobalDocId id, + const typename BaseVectorIndex::VectorPtr& vector) { if (vector) { adapter_->Add(vector.get(), id); } } -std::vector> HnswVectorIndex::Knn(float* target, size_t k, - std::optional ef) const { +std::vector> HnswVectorIndex::Knn(float* target, size_t k, + std::optional ef) const { return adapter_->Knn(target, k, ef); } -std::vector> HnswVectorIndex::Knn(float* target, size_t k, - std::optional ef, - const std::vector& allowed) const { + +std::vector> HnswVectorIndex::Knn( + float* target, size_t k, std::optional ef, + const std::vector& allowed) const { return adapter_->Knn(target, k, ef, allowed); } -void HnswVectorIndex::Remove(DocId id, const DocumentAccessor& doc, string_view field) { +void HnswVectorIndex::Remove(GlobalDocId id, const DocumentAccessor& doc, string_view field) { adapter_->Remove(id); } diff --git a/src/core/search/indices.h b/src/core/search/indices.h index 2f9e54e826d6..5311f1eb26c8 100644 --- a/src/core/search/indices.h +++ b/src/core/search/indices.h @@ -8,6 +8,9 @@ #include #include +// #include "server/search/global_vector_index.h" +#include "util/fibers/synchronization.h" + // Wrong warning reported when geometry.hpp is loaded #ifndef __clang__ #pragma GCC diagnostic push @@ -39,7 +42,7 @@ namespace dfly::search { // Index for integer fields. // Range bounds are queried in logarithmic time, iteration is constant. -struct NumericIndex : public BaseIndex { +struct NumericIndex : public BaseIndex { // Temporary base class for range tree. // It is used to use two different range trees depending on the flag use_range_tree. // If the flag is true, RangeTree is used, otherwise a simple implementation with btree_set. @@ -76,7 +79,7 @@ struct NumericIndex : public BaseIndex { }; // Base index for string based indices. -template struct BaseStringIndex : public BaseIndex { +template struct BaseStringIndex : public BaseIndex { using Container = BlockList; using VecOrPtr = std::variant, const Container*>; @@ -157,65 +160,100 @@ struct TagIndex : public BaseStringIndex> { char separator_; }; -struct BaseVectorIndex : public BaseIndex { - std::pair Info() const; +template struct BaseVectorIndex : public BaseIndex { + std::pair Info() const { + return {dim_, sim_}; + } - bool Add(DocId id, const DocumentAccessor& doc, std::string_view field) override final; + bool Add(T id, const DocumentAccessor& doc, std::string_view field) override final; protected: BaseVectorIndex(size_t dim, VectorSimilarity sim); using VectorPtr = decltype(std::declval().first); - virtual void AddVector(DocId id, const VectorPtr& vector) = 0; + virtual void AddVector(T id, const VectorPtr& vector) = 0; size_t dim_; VectorSimilarity sim_; }; +// ShardNoOpVectorIndex is used as placeholder as vector index in each shard. It doesn't implement +// any functionality so adding documents will not have any effect on it. It is used to support +// as filter when adding fields. +struct ShardNoOpVectorIndex : public BaseVectorIndex { + explicit ShardNoOpVectorIndex(const SchemaField::VectorParams& params); + + void Remove(DocId id, const DocumentAccessor& doc, std::string_view field) override { + // noop + } + + // Return all documents that have vectors in this index + std::vector GetAllDocsWithNonNullValues() const override { + return {}; + } + + protected: + using BaseVectorIndex::dim_; + void AddVector(DocId id, const typename BaseVectorIndex::VectorPtr& vector) override { + // noop + } +}; + // Index for vector fields. // Only supports lookup by id. -struct FlatVectorIndex : public BaseVectorIndex { - FlatVectorIndex(const SchemaField::VectorParams& params, PMR_NS::memory_resource* mr); +struct FlatVectorIndex : public BaseVectorIndex { + FlatVectorIndex(const SchemaField::VectorParams& params, ShardId shard_set_size, + PMR_NS::memory_resource* mr); - void Remove(DocId id, const DocumentAccessor& doc, std::string_view field) override; + void Remove(GlobalDocId id, const DocumentAccessor& doc, std::string_view field) override; - const float* Get(DocId doc) const; + std::vector> Knn(float* target, size_t k) const; - // Return all documents that have vectors in this index - std::vector GetAllDocsWithNonNullValues() const override; + using FilterShardDocs = std::vector; + + std::vector> Knn(float* target, size_t k, + const std::vector& allowed) const; + + std::vector GetAllDocsWithNonNullValues() const override { + return std::vector{}; + } protected: - void AddVector(DocId id, const VectorPtr& vector) override; + using BaseVectorIndex::dim_; + void AddVector(GlobalDocId id, + const typename BaseVectorIndex::VectorPtr& vector) override; private: - PMR_NS::vector entries_; + PMR_NS::vector> entries_; + mutable std::vector shard_vector_locks_; }; -struct HnswlibAdapter; - -struct HnswVectorIndex : public BaseVectorIndex { +template struct HnswlibAdapter; +struct HnswVectorIndex : public BaseVectorIndex { HnswVectorIndex(const SchemaField::VectorParams& params, PMR_NS::memory_resource* mr); ~HnswVectorIndex(); - void Remove(DocId id, const DocumentAccessor& doc, std::string_view field) override; + void Remove(GlobalDocId id, const DocumentAccessor& doc, std::string_view field) override; - std::vector> Knn(float* target, size_t k, std::optional ef) const; - std::vector> Knn(float* target, size_t k, std::optional ef, - const std::vector& allowed) const; + std::vector> Knn(float* target, size_t k, + std::optional ef) const; + std::vector> Knn(float* target, size_t k, std::optional ef, + const std::vector& allowed) const; // TODO: Implement if needed - std::vector GetAllDocsWithNonNullValues() const override { - return std::vector{}; + std::vector GetAllDocsWithNonNullValues() const override { + return std::vector{}; } protected: - void AddVector(DocId id, const VectorPtr& vector) override; + void AddVector(GlobalDocId id, + const typename BaseVectorIndex::VectorPtr& vector) override; private: - std::unique_ptr adapter_; + std::unique_ptr> adapter_; }; -struct GeoIndex : public BaseIndex { +struct GeoIndex : public BaseIndex { using point = boost::geometry::model::point>; diff --git a/src/core/search/search.cc b/src/core/search/search.cc index 512c75b54160..3cc1b650d076 100644 --- a/src/core/search/search.cc +++ b/src/core/search/search.cc @@ -122,7 +122,7 @@ struct BasicSearch { profile_builder_ = ProfileBuilder{}; } - BaseIndex* GetBaseIndex(string_view field) { + BaseIndex* GetBaseIndex(string_view field) { auto index = indices_->GetIndex(field); if (!index) { error_ = absl::StrCat("Invalid field: ", field); @@ -133,7 +133,7 @@ struct BasicSearch { // Get casted sub index by field template T* GetIndex(string_view field) { - static_assert(is_base_of_v); + static_assert(is_base_of_v, T>); auto base_index = GetBaseIndex(field); if (!base_index) { @@ -149,7 +149,7 @@ struct BasicSearch { return casted_ptr; } - BaseSortIndex* GetSortIndex(string_view field) { + BaseSortIndex* GetSortIndex(string_view field) { auto index = indices_->GetSortIndex(field); if (!index) { error_ = absl::StrCat("Invalid sort field: ", field); @@ -210,13 +210,13 @@ struct BasicSearch { IndexResult Search(const AstStarFieldNode& node, string_view active_field) { // Try to get a sort index first, as `@field:*` might imply wanting sortable behavior - BaseSortIndex* sort_index = indices_->GetSortIndex(active_field); + BaseSortIndex* sort_index = indices_->GetSortIndex(active_field); if (sort_index) { return IndexResult{sort_index->GetAllDocsWithNonNullValues()}; } // If sort index doesn't exist try regular index - BaseIndex* base_index = GetBaseIndex(active_field); + BaseIndex* base_index = GetBaseIndex(active_field); return base_index ? IndexResult{base_index->GetAllDocsWithNonNullValues()} : IndexResult{}; } @@ -336,66 +336,10 @@ struct BasicSearch { return UnifyResults(GetSubResults(node.tags, mapping), LogicOp::OR); } - void SearchKnnFlat(FlatVectorIndex* vec_index, const AstKnnNode& knn, IndexResult&& sub_results) { - knn_distances_.reserve(sub_results.ApproximateSize()); - auto cb = [&](auto* set) { - auto [dim, sim] = vec_index->Info(); - for (DocId matched_doc : *set) { - float dist = VectorDistance(knn.vec.first.get(), vec_index->Get(matched_doc), dim, sim); - knn_distances_.emplace_back(dist, matched_doc); - } - }; - visit(cb, sub_results.Borrowed()); - - size_t prefix_size = min(knn.limit, knn_distances_.size()); - partial_sort(knn_distances_.begin(), knn_distances_.begin() + prefix_size, - knn_distances_.end()); - knn_distances_.resize(prefix_size); - } - - void SearchKnnHnsw(HnswVectorIndex* vec_index, const AstKnnNode& knn, IndexResult&& sub_results) { - if (indices_->GetAllDocs().size() == sub_results.ApproximateSize()) // TODO: remove approx size - knn_distances_ = vec_index->Knn(knn.vec.first.get(), knn.limit, knn.ef_runtime); - else - knn_distances_ = - vec_index->Knn(knn.vec.first.get(), knn.limit, knn.ef_runtime, sub_results.Take().first); - } - // [KNN limit @field vec]: Compute distance from `vec` to all vectors keep closest `limit` IndexResult Search(const AstKnnNode& knn, string_view active_field) { - DCHECK(active_field.empty()); - auto sub_results = SearchGeneric(*knn.filter, active_field); - - auto* vec_index = GetIndex(knn.field); - if (!vec_index) - return IndexResult{}; - - // If vector dimension is 0, treat as placeholder/invalid - return empty results - // This allows tests to use dummy vector values like "" - if (knn.vec.second == 0) - return IndexResult{}; - - if (auto [dim, _] = vec_index->Info(); dim != knn.vec.second) { - error_ = - absl::StrCat("Wrong vector index dimensions, got: ", knn.vec.second, ", expected: ", dim); - return IndexResult{}; - } - - knn_scores_.clear(); - if (auto hnsw_index = dynamic_cast(vec_index); hnsw_index) - SearchKnnHnsw(hnsw_index, knn, std::move(sub_results)); - else - SearchKnnFlat(dynamic_cast(vec_index), knn, std::move(sub_results)); - - vector out(knn_distances_.size()); - knn_scores_.reserve(knn_distances_.size()); - - for (size_t i = 0; i < knn_distances_.size(); i++) { - knn_scores_.emplace_back(knn_distances_[i].second, knn_distances_[i].first); - out[i] = knn_distances_[i].second; - } - - return IndexResult{std::move(out)}; + LOG(DFATAL) << "KNN node should not be searched in shard"; + return IndexResult{}; } // Determine node type and call specific search function @@ -501,24 +445,15 @@ void FieldIndices::CreateIndices(PMR_NS::memory_resource* mr) { indices_[field_ident] = make_unique(mr, tparams); break; } - case SchemaField::VECTOR: { - unique_ptr vector_index; - - DCHECK(holds_alternative(field_info.special_params)); - const auto& vparams = std::get(field_info.special_params); - - if (vparams.use_hnsw) - vector_index = make_unique(vparams, mr); - else - vector_index = make_unique(vparams, mr); - - indices_[field_ident] = std::move(vector_index); - break; - } case SchemaField::GEO: { indices_[field_ident] = make_unique(mr); break; } + case SchemaField::VECTOR: { + const auto& vparams = std::get(field_info.special_params); + indices_[field_ident] = make_unique(vparams); + break; + } } } } @@ -546,7 +481,7 @@ void FieldIndices::CreateSortIndices(PMR_NS::memory_resource* mr) { bool FieldIndices::Add(DocId doc, const DocumentAccessor& access) { bool was_added = true; - std::vector> successfully_added_indices; + std::vector*>> successfully_added_indices; successfully_added_indices.reserve(indices_.size() + sort_indices_.size()); auto try_add = [&](const auto& indices_container) { @@ -588,12 +523,12 @@ void FieldIndices::Remove(DocId doc, const DocumentAccessor& access) { all_ids_.erase(it); } -BaseIndex* FieldIndices::GetIndex(string_view field) const { +BaseIndex* FieldIndices::GetIndex(string_view field) const { auto it = indices_.find(schema_.LookupAlias(field)); return it != indices_.end() ? it->second.get() : nullptr; } -BaseSortIndex* FieldIndices::GetSortIndex(string_view field) const { +BaseSortIndex* FieldIndices::GetSortIndex(string_view field) const { auto it = sort_indices_.find(schema_.LookupAlias(field)); return it != sort_indices_.end() ? it->second.get() : nullptr; } @@ -664,14 +599,21 @@ SearchResult SearchAlgorithm::Search(const FieldIndices* index, size_t cuttoff_l return bs.Search(*query_, cuttoff_limit); } -optional SearchAlgorithm::GetKnnScoreSortOption() const { - DCHECK(query_); - - // KNN query - if (auto* knn = get_if(query_.get()); knn) - return KnnScoreSortOption{string_view{knn->score_alias}, knn->limit}; +bool SearchAlgorithm::IsKnnQuery() const { + return std::holds_alternative(*query_); +} - return nullopt; +std::unique_ptr SearchAlgorithm::GetKnnNode() { + if (auto* knn = get_if(query_.get()); knn) { + // Save knn score sort option + knn_score_sort_option_ = KnnScoreSortOption{string_view{knn->score_alias}, knn->limit}; + auto node = std::move(query_); + if (!std::holds_alternative(*(knn)->filter)) + query_.swap(knn->filter); + return node; + } + LOG(DFATAL) << "Should not reach here"; + return nullptr; } void SearchAlgorithm::EnableProfiling() { diff --git a/src/core/search/search.h b/src/core/search/search.h index 96f2b4271e04..8cc18b363c45 100644 --- a/src/core/search/search.h +++ b/src/core/search/search.h @@ -23,6 +23,7 @@ namespace dfly::search { struct AstNode; struct TextIndex; +struct AstKnnNode; // Optional FILTER struct OptionalNumericFilter : public OptionalFilterBase { @@ -129,8 +130,8 @@ class FieldIndices { bool Add(DocId doc, const DocumentAccessor& access); void Remove(DocId doc, const DocumentAccessor& access); - BaseIndex* GetIndex(std::string_view field) const; - BaseSortIndex* GetSortIndex(std::string_view field) const; + BaseIndex* GetIndex(std::string_view field) const; + BaseSortIndex* GetSortIndex(std::string_view field) const; std::vector GetAllTextIndices() const; const std::vector& GetAllDocs() const; @@ -149,8 +150,8 @@ class FieldIndices { const Schema& schema_; const IndicesOptions& options_; std::vector all_ids_; - absl::flat_hash_map> indices_; - absl::flat_hash_map> sort_indices_; + absl::flat_hash_map>> indices_; + absl::flat_hash_map>> sort_indices_; const Synonyms* synonyms_; }; @@ -201,14 +202,20 @@ class SearchAlgorithm { SearchResult Search(const FieldIndices* index, size_t cuttoff_limit = std::numeric_limits::max()) const; - // if enabled, return limit & alias for knn query - std::optional GetKnnScoreSortOption() const; + const std::optional& GetKnnScoreSortOption() const { + return knn_score_sort_option_; + } + + bool IsKnnQuery() const; + + std::unique_ptr GetKnnNode(); void EnableProfiling(); private: bool profiling_enabled_ = false; std::unique_ptr query_; + std::optional knn_score_sort_option_; }; } // namespace dfly::search diff --git a/src/core/search/sort_indices.h b/src/core/search/sort_indices.h index 5f856157126e..57d55d704a2a 100644 --- a/src/core/search/sort_indices.h +++ b/src/core/search/sort_indices.h @@ -20,7 +20,7 @@ namespace dfly::search { -template struct SimpleValueSortIndex : public BaseSortIndex { +template struct SimpleValueSortIndex : public BaseSortIndex { protected: struct ParsedSortValue { bool HasValue() const; diff --git a/src/external_libs.cmake b/src/external_libs.cmake index 7ddd09062e53..4ba6166a39e7 100644 --- a/src/external_libs.cmake +++ b/src/external_libs.cmake @@ -134,7 +134,9 @@ if (WITH_SEARCH) add_third_party( hnswlib - URL https://github.com/nmslib/hnswlib/archive/refs/tags/v0.8.0.tar.gz + GIT_REPOSITORY https://github.com/dragonflydb/hnswlib.git + # HEAD of dragonfly branch + GIT_TAG d07dd1da2bf48b85d2f03b8396193ad7120f75c2 BUILD_COMMAND echo SKIP INSTALL_COMMAND cp -R /hnswlib ${THIRD_PARTY_LIB_DIR}/hnswlib/include/ diff --git a/src/server/main_service.cc b/src/server/main_service.cc index f9c44c938a30..99f0fa3e2f23 100644 --- a/src/server/main_service.cc +++ b/src/server/main_service.cc @@ -59,6 +59,7 @@ extern "C" { #include "server/multi_command_squasher.h" #include "server/namespaces.h" #include "server/script_mgr.h" +#include "server/search/global_vector_index.h" #include "server/search/search_family.h" #include "server/server_state.h" #include "server/set_family.h" @@ -1123,6 +1124,11 @@ void Service::Shutdown() { shard_set->PreShutdown(); shard_set->Shutdown(); + +#ifdef WITH_SEARCH + GlobalVectorIndexRegistry::Instance().Reset(); +#endif + Transaction::Shutdown(); pp_.AwaitFiberOnAll([](ProactorBase* pb) { ServerState::tlocal()->Destroy(); }); diff --git a/src/server/search/CMakeLists.txt b/src/server/search/CMakeLists.txt index 81160d2d69f8..31ee0a90191e 100644 --- a/src/server/search/CMakeLists.txt +++ b/src/server/search/CMakeLists.txt @@ -4,7 +4,7 @@ if (NOT WITH_SEARCH) return() endif() -add_library(dfly_search_server aggregator.cc doc_accessors.cc doc_index.cc search_family.cc index_join.cc +add_library(dfly_search_server aggregator.cc doc_accessors.cc doc_index.cc search_family.cc index_join.cc global_vector_index.cc ../cluster/coordinator.cc) target_link_libraries(dfly_search_server dfly_transaction dragonfly_lib dfly_facade redis_lib jsonpath TRDP::jsoncons) diff --git a/src/server/search/doc_index.cc b/src/server/search/doc_index.cc index f8a744e16a9f..f60f8ffff272 100644 --- a/src/server/search/doc_index.cc +++ b/src/server/search/doc_index.cc @@ -16,6 +16,7 @@ #include "server/engine_shard_set.h" #include "server/family_utils.h" #include "server/search/doc_accessors.h" +#include "server/search/global_vector_index.h" #include "server/server_state.h" namespace dfly { @@ -219,17 +220,15 @@ ShardDocIndex::DocId ShardDocIndex::DocKeyIndex::Add(string_view key) { return id; } -std::optional ShardDocIndex::DocKeyIndex::Remove(string_view key) { - auto it = ids_.extract(key); - if (!it) { - return std::nullopt; - } +std::optional ShardDocIndex::DocKeyIndex::Find(string_view key) const { + auto it = ids_.find(key); + return it != ids_.end() ? std::make_optional(it->second) : std::nullopt; +} - const DocId id = it.mapped(); +void ShardDocIndex::DocKeyIndex::Remove(DocId id) { + ids_.extract(keys_[id]); keys_[id] = ""; free_ids_.push_back(id); - - return id; } string_view ShardDocIndex::DocKeyIndex::Get(DocId id) const { @@ -274,7 +273,7 @@ void ShardDocIndex::Rebuild(const OpArgs& op_args, PMR_NS::memory_resource* mr) auto cb = [this](string_view key, const BaseAccessor& doc) { DocId id = key_index_.Add(key); if (!indices_->Add(id, doc)) { - key_index_.Remove(key); + key_index_.Remove(id); } }; @@ -333,36 +332,100 @@ void ShardDocIndex::RebuildForGroup(const OpArgs& op_args, const std::string_vie update_indices(false); } -void ShardDocIndex::AddDoc(string_view key, const DbContext& db_cntx, const PrimeValue& pv) { +std::optional ShardDocIndex::GetDocId(std::string_view key, + const DbContext& db_cntx) { if (!indices_) - return; + return std::nullopt; + + // Only handle documents from database 0 + if (db_cntx.db_index != 0) + return std::nullopt; + + return key_index_.Find(key); +} + +std::optional ShardDocIndex::AddDoc(string_view key, const DbContext& db_cntx, + const PrimeValue& pv) { + if (!indices_) + return std::nullopt; // Only index documents from database 0 if (db_cntx.db_index != 0) - return; + return std::nullopt; + ; auto accessor = GetAccessor(db_cntx, pv); DocId id = key_index_.Add(key); if (!indices_->Add(id, *accessor)) { - key_index_.Remove(key); + key_index_.Remove(id); + return std::nullopt; } + + return id; } -void ShardDocIndex::RemoveDoc(string_view key, const DbContext& db_cntx, const PrimeValue& pv) { - if (!indices_) - return; +void ShardDocIndex::RemoveDoc(DocId id, const DbContext& db_cntx, const PrimeValue& pv) { + auto accessor = GetAccessor(db_cntx, pv); + key_index_.Remove(id); + indices_->Remove(id, *accessor); +} - // Only handle documents from database 0 - if (db_cntx.db_index != 0) - return; +void ShardDocIndex::AddDocToGlobalVectorIndex(std::string_view index_name, + ShardDocIndex::DocId doc_id, const DbContext& db_cntx, + const PrimeValue& pv) { + auto accessor = GetAccessor(db_cntx, pv); + + GlobalDocId global_id = search::CreateGlobalDocId(EngineShard::tlocal()->shard_id(), doc_id); + for (const auto& [field_ident, field_info] : base_->schema.fields) { + if (field_info.type == search::SchemaField::VECTOR && + !(field_info.flags & search::SchemaField::NOINDEX)) { + GlobalVectorIndexRegistry::Instance() + .GetVectorIndex(index_name, field_info.short_name) + ->Add(global_id, *accessor, field_ident); + } + } +} + +void ShardDocIndex::RemoveDocFromGlobalVectorIndex(std::string_view index_name, + ShardDocIndex::DocId doc_id, + const DbContext& db_cntx, const PrimeValue& pv) { auto accessor = GetAccessor(db_cntx, pv); - auto id = key_index_.Remove(key); - if (id) { - indices_->Remove(id.value(), *accessor); + GlobalDocId global_id = search::CreateGlobalDocId(EngineShard::tlocal()->shard_id(), doc_id); + + for (const auto& [field_ident, field_info] : base_->schema.fields) { + if (field_info.type == search::SchemaField::VECTOR && + !(field_info.flags & search::SchemaField::NOINDEX)) { + if (auto global_index = GlobalVectorIndexRegistry::Instance().GetVectorIndex( + index_name, field_info.short_name)) { + global_index->Remove(global_id, *accessor, field_ident); + } + } } } +void ShardDocIndex::RebuildGlobalVectorIndices(std::string_view index_name, const OpArgs& op_args) { + auto cb = [this, index_name](string_view key, const BaseAccessor& doc) { + auto local_id = key_index_.Find(key); + + if (!local_id) + return; + + GlobalDocId global_id = search::CreateGlobalDocId(EngineShard::tlocal()->shard_id(), *local_id); + + for (const auto& [field_ident, field_info] : base_->schema.fields) { + if (field_info.type == search::SchemaField::VECTOR && + !(field_info.flags & search::SchemaField::NOINDEX)) { + GlobalVectorIndexRegistry::Instance() + .GetVectorIndex(index_name, field_info.short_name) + ->Add(global_id, doc, field_ident); + } + } + }; + + TraverseAllMatching(*base_, op_args, cb); +} + bool ShardDocIndex::Matches(string_view key, unsigned obj_code) const { return base_->Matches(key, obj_code); } @@ -482,7 +545,7 @@ SearchResult ShardDocIndex::Search(const OpArgs& op_args, const SearchParams& pa // Don't load entry if we need only its key. Ignore expiration. if (params.IdsOnly()) { string_view key = key_index_.Get(result.ids[i]); - out.push_back({string{key}, {}, knn_score, sort_score}); + out.push_back({result.ids[i], string{key}, {}, knn_score, sort_score}); continue; } @@ -501,7 +564,7 @@ SearchResult ShardDocIndex::Search(const OpArgs& op_args, const SearchParams& pa auto more_fields = accessor->Serialize(base_->schema, return_fields); fields.insert(make_move_iterator(more_fields.begin()), make_move_iterator(more_fields.end())); - out.push_back({string{key}, std::move(fields), knn_score, sort_score}); + out.push_back({result.ids[i], string{key}, std::move(fields), knn_score, sort_score}); } return {result.total - expired_count, std::move(out), std::move(result.profile)}; @@ -635,7 +698,7 @@ DocIndexInfo ShardDocIndex::GetInfo() const { } io::Result ShardDocIndex::GetTagVals(string_view field) const { - search::BaseIndex* base_index = indices_->GetIndex(field); + search::BaseIndex* base_index = indices_->GetIndex(field); if (base_index == nullptr) { return make_unexpected(ErrorReply{"-No such field"}); } @@ -698,8 +761,10 @@ void ShardDocIndices::DropIndexCache(const dfly::ShardDocIndex& shard_doc_index) } void ShardDocIndices::RebuildAllIndices(const OpArgs& op_args) { - for (auto& [_, ptr] : indices_) + for (auto& [index_name, ptr] : indices_) { ptr->Rebuild(op_args, &local_mr_); + ptr->RebuildGlobalVectorIndices(index_name, op_args); + } } vector ShardDocIndices::GetIndexNames() const { @@ -712,17 +777,26 @@ vector ShardDocIndices::GetIndexNames() const { void ShardDocIndices::AddDoc(string_view key, const DbContext& db_cntx, const PrimeValue& pv) { DCHECK(IsIndexedKeyType(pv)); - for (auto& [_, index] : indices_) { - if (index->Matches(key, pv.ObjType())) - index->AddDoc(key, db_cntx, pv); + for (auto& [index_name, index] : indices_) { + if (index->Matches(key, pv.ObjType())) { + std::optional doc_id = index->AddDoc(key, db_cntx, pv); + if (doc_id) { + index->AddDocToGlobalVectorIndex(index_name, *doc_id, db_cntx, pv); + } + } } } void ShardDocIndices::RemoveDoc(string_view key, const DbContext& db_cntx, const PrimeValue& pv) { DCHECK(IsIndexedKeyType(pv)); - for (auto& [_, index] : indices_) { - if (index->Matches(key, pv.ObjType())) - index->RemoveDoc(key, db_cntx, pv); + for (auto& [index_name, index] : indices_) { + if (index->Matches(key, pv.ObjType())) { + std::optional doc_id = index->GetDocId(key, db_cntx); + if (doc_id) { + index->RemoveDocFromGlobalVectorIndex(index_name, *doc_id, db_cntx, pv); + index->RemoveDoc(*doc_id, db_cntx, pv); + } + } } } diff --git a/src/server/search/doc_index.h b/src/server/search/doc_index.h index e1b758ac91d3..f147aa46e7ef 100644 --- a/src/server/search/doc_index.h +++ b/src/server/search/doc_index.h @@ -33,6 +33,7 @@ using Synonyms = search::Synonyms; std::string_view SearchFieldTypeToString(search::SchemaField::FieldType); struct SerializedSearchDoc { + search::DocId id; std::string key; SearchDocData values; float knn_score; @@ -212,6 +213,7 @@ class ShardDocIndices; class ShardDocIndex { friend class ShardDocIndices; using DocId = search::DocId; + using GlobalDocId = search::GlobalDocId; // Used in FieldsValuesPerDocId to store values for each field per document using FieldsValues = absl::InlinedVector; @@ -219,9 +221,10 @@ class ShardDocIndex { // DocKeyIndex manages mapping document keys to ids and vice versa through a simple interface. struct DocKeyIndex { DocId Add(std::string_view key); - std::optional Remove(std::string_view key); + void Remove(DocId id); std::string_view Get(DocId id) const; + std::optional Find(std::string_view key) const; size_t Size() const; // Get const reference to the internal ids map @@ -262,8 +265,12 @@ class ShardDocIndex { // Return whether base index matches bool Matches(std::string_view key, unsigned obj_code) const; - void AddDoc(std::string_view key, const DbContext& db_cntx, const PrimeValue& pv); - void RemoveDoc(std::string_view key, const DbContext& db_cntx, const PrimeValue& pv); + std::optional GetDocId(std::string_view key, const DbContext& db_cntx); + + std::optional AddDoc(std::string_view key, const DbContext& db_cntx, + const PrimeValue& pv); + + void RemoveDoc(DocId id, const DbContext& db_cntx, const PrimeValue& pv); DocIndexInfo GetInfo() const; @@ -287,13 +294,19 @@ class ShardDocIndex { return key_index_; } - private: - // Clears internal data. Traverses all matching documents and assigns ids. - void Rebuild(const OpArgs& op_args, PMR_NS::memory_resource* mr); + void AddDocToGlobalVectorIndex(std::string_view index_name, ShardDocIndex::DocId doc_id, + const DbContext& db_cntx, const PrimeValue& pv); + void RemoveDocFromGlobalVectorIndex(std::string_view index_name, ShardDocIndex::DocId doc_id, + const DbContext& db_cntx, const PrimeValue& pv); + void RebuildGlobalVectorIndices(std::string_view index_name, const OpArgs& op_args); using LoadedEntry = std::pair>; std::optional LoadEntry(search::DocId id, const OpArgs& op_args) const; + private: + // Clears internal data. Traverses all matching documents and assigns ids. + void Rebuild(const OpArgs& op_args, PMR_NS::memory_resource* mr); + // Behaviour identical to SortIndex::Sort for non-sortable fields that need to be fetched first std::vector KeepTopKSorted(std::vector* ids, size_t limit, const SearchParams::SortOption& sort, diff --git a/src/server/search/global_vector_index.cc b/src/server/search/global_vector_index.cc new file mode 100644 index 000000000000..0de5a1ad035e --- /dev/null +++ b/src/server/search/global_vector_index.cc @@ -0,0 +1,284 @@ +// Copyright 2025, DragonflyDB authors. All rights reserved. +// See LICENSE for licensing terms. +// + +#include "server/search/global_vector_index.h" + +#include + +#include +#include +#include +#include + +#include "base/logging.h" +#include "core/search/ast_expr.h" +#include "core/search/base.h" +#include "core/search/index_result.h" +#include "core/search/indices.h" +#include "core/search/vector_utils.h" +#include "server/engine_shard.h" +#include "server/engine_shard_set.h" +#include "server/search/doc_accessors.h" +#include "server/search/doc_index.h" +#include "server/transaction.h" +#include "server/tx_base.h" + +namespace dfly { + +GlobalVectorIndex::GlobalVectorIndex(const search::SchemaField::VectorParams& params, + std::string_view index_name, PMR_NS::memory_resource* mr) + : params_(params), index_name_(index_name) { + if (params.use_hnsw) { + vector_index_ = std::make_unique(params, mr); + } else { + vector_index_ = std::make_unique(params, shard_set->size(), mr); + } +} + +GlobalVectorIndex::~GlobalVectorIndex() = default; + +bool GlobalVectorIndex::Add(search::GlobalDocId id, const search::DocumentAccessor& doc, + std::string_view field) { + return vector_index_->Add(id, doc, field); +} + +void GlobalVectorIndex::Remove(search::GlobalDocId id, const search::DocumentAccessor& doc, + std::string_view field) { + vector_index_->Remove(id, doc, field); +} + +std::vector> GlobalVectorIndex::SearchKnnHnsw( + search::HnswVectorIndex* index, const search::AstKnnNode* knn, + const std::optional>& allowed_docs) { + if (allowed_docs) + return index->Knn(knn->vec.first.get(), knn->limit, knn->ef_runtime, *allowed_docs); + else + return index->Knn(knn->vec.first.get(), knn->limit, knn->ef_runtime); +} + +std::vector> GlobalVectorIndex::SearchKnnFlat( + search::FlatVectorIndex* index, const search::AstKnnNode* knn, + const std::optional>& allowed_docs) { + if (allowed_docs) + return index->Knn(knn->vec.first.get(), knn->limit, *allowed_docs); + else + return index->Knn(knn->vec.first.get(), knn->limit); +} + +std::vector GlobalVectorIndex::Search( + const search::AstKnnNode* knn_node, + const std::optional& knn_score_option, + const std::vector& shard_filter_docs, const SearchParams& params, + const CommandContext& cmd_cntx) { + std::vector results(1); + + std::optional> filter_docs_global_ids = std::nullopt; + std::optional> filter_docs_shard_ids = + std::nullopt; + std::map filter_docs_lookup; + + const ShardId shard_size = shard_filter_docs.size(); + + // We have pre filter so all documents should already be fetched + if (knn_node->Filter()) { + std::vector global_ids; + std::vector shard_ids; + shard_ids.resize(shard_size); + for (size_t shard_id = 0; shard_id < shard_size; shard_id++) { + for (auto& doc : shard_filter_docs[shard_id].docs) { + auto global_doc_id = search::CreateGlobalDocId(shard_id, doc.id); + global_ids.emplace_back(global_doc_id); + shard_ids[shard_id].push_back(doc.id); + filter_docs_lookup[global_doc_id] = &doc; + } + } + filter_docs_global_ids = std::move(global_ids); + filter_docs_shard_ids = std::move(shard_ids); + } + + std::vector> knn_results; + if (auto hnsw_index = dynamic_cast(vector_index_.get()); hnsw_index) { + knn_results = SearchKnnHnsw(hnsw_index, knn_node, filter_docs_global_ids); + } else if (auto flat_index = dynamic_cast(vector_index_.get()); + flat_index) { + knn_results = SearchKnnFlat(flat_index, knn_node, filter_docs_shard_ids); + } + + std::vector knn_result_docs; + knn_result_docs.reserve(knn_results.size()); + + // Group by shard with minimal allocations + std::vector>> shard_doc_ids(shard_size); + + for (const auto& [score, global_id] : knn_results) { + if (knn_node->Filter()) { + knn_result_docs.emplace_back(*filter_docs_lookup[global_id]); + // Update knn score + knn_result_docs.back().knn_score = score; + } else { + ShardId shard_id = search::GlobalDocIdShardId(global_id); + search::DocId doc_id = search::GlobalDocIdLocalId(global_id); + shard_doc_ids[shard_id].emplace_back(score, doc_id); + } + } + + if (knn_node->Filter()) { + results[0].total_hits = knn_results.size(); + results[0].docs = std::move(knn_result_docs); + return results; + } + + // Use per-shard vectors to avoid race conditions, but keep them minimal + std::vector> shard_docs(shard_size); + std::atomic index_not_found{false}; + + bool should_fetch_sort_field = false; + if (params.sort_option) { + should_fetch_sort_field = !params.sort_option->IsSame(*knn_score_option); + } + + cmd_cntx.tx->ScheduleSingleHop([&](Transaction* t, EngineShard* es) { + auto* index = es->search_indices()->GetIndex(index_name_); + + if (!index) { + index_not_found.store(true); + return OpStatus::OK; + } + + auto& shard_requests = shard_doc_ids[es->shard_id()]; + if (shard_requests.empty()) + return OpStatus::OK; + + auto& docs_for_shard = shard_docs[es->shard_id()]; + docs_for_shard.reserve(shard_requests.size()); + + // Cache schema reference to avoid repeated lookups + const auto& schema = index->GetInfo().base_index.schema; + + // Optimize serialization based on query type + if (params.ShouldReturnAllFields()) { + // Full serialization for full queries + for (const auto& [score, doc_id] : shard_requests) { + if (auto entry = index->LoadEntry(doc_id, t->GetOpArgs(es))) { + auto& [key, accessor] = *entry; + auto fields = accessor->Serialize(schema); + + // If we use SORT we need to update`sort_score` + search::SortableValue sort_score = std::monostate{}; + if (should_fetch_sort_field) { + sort_score = fields[params.sort_option->field.Name()]; + } + + docs_for_shard.push_back( + {doc_id, std::string{key}, std::move(fields), score, sort_score}); + } + } + } else { + // Selective field serialization + auto return_fields = params.return_fields.value_or(std::vector{}); + + bool should_return_sort_field = false; + if (should_fetch_sort_field) { + for (const auto& return_field : return_fields) { + if (params.sort_option->field.Name() == return_field.Name()) { + should_return_sort_field = true; + } + } + } + + // Sort field is not returned so we need to inject it + if (should_fetch_sort_field && !should_return_sort_field) { + return_fields.push_back(params.sort_option->field); + } + + for (const auto& [score, doc_id] : shard_requests) { + if (auto entry = index->LoadEntry(doc_id, t->GetOpArgs(es))) { + auto& [key, accessor] = *entry; + auto fields = return_fields.empty() ? SearchDocData{} + // NOCONTENT query - no fields needed + : accessor->Serialize(schema, return_fields); + + search::SortableValue sort_score = std::monostate{}; + if (should_fetch_sort_field) { + sort_score = fields[params.sort_option->field.Name()]; + // Erase sort field from returned fields + if (!should_return_sort_field) { + fields.erase(params.sort_option->field.Name()); + } + } + + docs_for_shard.push_back( + {doc_id, std::string{key}, std::move(fields), score, sort_score}); + } + } + } + + return OpStatus::OK; + }); + + // Crete single vector of aggregated documents from all shards + for (auto& docs : shard_docs) { + knn_result_docs.insert(knn_result_docs.end(), std::make_move_iterator(docs.begin()), + std::make_move_iterator(docs.end())); + } + + results[0].total_hits = knn_result_docs.size(); + results[0].docs = std::move(knn_result_docs); + return results; +} + +// Global registry implementation +GlobalVectorIndexRegistry& GlobalVectorIndexRegistry::Instance() { + static GlobalVectorIndexRegistry instance; + return instance; +} + +bool GlobalVectorIndexRegistry::CreateVectorIndex(std::string_view index_name, + std::string_view field_name, + const search::SchemaField::VectorParams& params) { + std::string key = MakeKey(index_name, field_name); + + std::shared_lock lock(registry_mutex_); + + auto it = indices_.find(key); + if (it != indices_.end()) + return false; + + indices_[key] = std::make_shared(params, index_name); + ; + return true; +} + +void GlobalVectorIndexRegistry::RemoveVectorIndex(std::string_view index_name, + std::string_view field_name) { + std::string key = MakeKey(index_name, field_name); + + std::unique_lock lock(registry_mutex_); + + auto it = indices_.find(key); + if (it != indices_.end()) + indices_.erase(it); +} + +std::shared_ptr GlobalVectorIndexRegistry::GetVectorIndex( + std::string_view index_name, std::string_view field_name) const { + std::string key = MakeKey(index_name, field_name); + + std::shared_lock lock(registry_mutex_); + + auto it = indices_.find(key); + return it != indices_.end() ? it->second : nullptr; +} + +void GlobalVectorIndexRegistry::Reset() { + std::unique_lock lock(registry_mutex_); + indices_.clear(); +} + +std::string GlobalVectorIndexRegistry::MakeKey(std::string_view index_name, + std::string_view field_name) const { + return absl::StrCat(index_name, ":", field_name); +} + +} // namespace dfly diff --git a/src/server/search/global_vector_index.h b/src/server/search/global_vector_index.h new file mode 100644 index 000000000000..07a39ce52af2 --- /dev/null +++ b/src/server/search/global_vector_index.h @@ -0,0 +1,83 @@ +// Copyright 2025, DragonflyDB authors. All rights reserved. +// See LICENSE for licensing terms. +// + +#pragma once + +#include + +#include +#include +#include +#include + +#include "core/search/base.h" +#include "core/search/index_result.h" +#include "core/search/indices.h" +#include "core/search/search.h" +#include "server/common.h" +#include "server/search/doc_index.h" +#include "server/tx_base.h" + +namespace dfly { + +struct KnnScoreSortOption; + +class GlobalVectorIndex { + public: + GlobalVectorIndex(const search::SchemaField::VectorParams& params, std::string_view index_name, + PMR_NS::memory_resource* mr = PMR_NS::get_default_resource()); + + ~GlobalVectorIndex(); + + bool Add(search::GlobalDocId id, const search::DocumentAccessor& doc, std::string_view field); + void Remove(search::GlobalDocId id, const search::DocumentAccessor& doc, std::string_view field); + + std::vector Search( + const search::AstKnnNode* knn, + const std::optional& knn_score_option, + const std::vector& filter_docs, const SearchParams& params, + const CommandContext& cmd_cntx); + + private: + std::vector> SearchKnnHnsw( + search::HnswVectorIndex* index, const search::AstKnnNode* knn, + const std::optional>& allowed_docs); + + std::vector> SearchKnnFlat( + search::FlatVectorIndex* index, const search::AstKnnNode* knn, + const std::optional>& allowed_docs); + + std::unique_ptr> vector_index_; + search::SchemaField::VectorParams params_; + std::string index_name_; +}; + +// Global registry for all vector indices +class GlobalVectorIndexRegistry { + public: + static GlobalVectorIndexRegistry& Instance(); + + // Create global vector index for given index name and field + bool CreateVectorIndex(std::string_view index_name, std::string_view field_name, + const search::SchemaField::VectorParams& params); + + // Remove vector index + void RemoveVectorIndex(std::string_view index_name, std::string_view field_name); + + // Get existing vector index + std::shared_ptr GetVectorIndex(std::string_view index_name, + std::string_view field_name) const; + + // Reset all vector indices + void Reset(); + + private: + GlobalVectorIndexRegistry() = default; + + mutable std::shared_mutex registry_mutex_; + absl::flat_hash_map> indices_; + std::string MakeKey(std::string_view index_name, std::string_view field_name) const; +}; + +} // namespace dfly diff --git a/src/server/search/search_family.cc b/src/server/search/search_family.cc index d9266c749d70..091b5c85c7fa 100644 --- a/src/server/search/search_family.cc +++ b/src/server/search/search_family.cc @@ -34,6 +34,7 @@ #include "server/engine_shard_set.h" #include "server/search/aggregator.h" #include "server/search/doc_index.h" +#include "server/search/global_vector_index.h" #include "server/transaction.h" #include "src/core/overloaded.h" @@ -1084,9 +1085,25 @@ void SearchFamily::FtCreate(CmdArgList args, const CommandContext& cmd_cntx) { } auto idx_ptr = make_shared(std::move(parsed_index).value()); + + for (const auto& [field_ident, field_info] : idx_ptr->schema.fields) { + if (field_info.type == search::SchemaField::VECTOR && + !(field_info.flags & search::SchemaField::NOINDEX)) { + const auto& vparams = std::get(field_info.special_params); + if (!GlobalVectorIndexRegistry::Instance().CreateVectorIndex(idx_name, field_info.short_name, + vparams)) { + cmd_cntx.tx->Conclude(); + return builder->SendError("Index already exists"); + } + } + } + cmd_cntx.tx->Execute( [idx_name, idx_ptr](auto* tx, auto* es) { es->search_indices()->InitIndex(tx->GetOpArgs(es), idx_name, idx_ptr); + if (auto* index = es->search_indices()->GetIndex(idx_name); index) { + index->RebuildGlobalVectorIndices(idx_name, tx->GetOpArgs(es)); + } return OpStatus::OK; }, true); @@ -1158,9 +1175,16 @@ void SearchFamily::FtDropIndex(CmdArgList args, const CommandContext& cmd_cntx) // Parse optional DD (Delete Documents) parameter bool delete_docs = args.size() > 1 && absl::EqualsIgnoreCase(args[1], "DD"); + shared_ptr index_info; atomic_uint num_deleted{0}; auto cb = [&](Transaction* t, EngineShard* es) { + // Get index info from first shard for global cleanup + if (es->shard_id() == 0) { + if (auto* idx = es->search_indices()->GetIndex(idx_name); idx != nullptr) { + index_info = make_shared(idx->GetInfo().base_index); + } + } // Drop the index and get its pointer auto index = es->search_indices()->DropIndex(idx_name); if (!index) @@ -1189,6 +1213,15 @@ void SearchFamily::FtDropIndex(CmdArgList args, const CommandContext& cmd_cntx) cmd_cntx.tx->Execute(cb, true); + if (index_info) { + for (const auto& [field_ident, field_info] : index_info->schema.fields) { + if (field_info.type == search::SchemaField::VECTOR && + !(field_info.flags & search::SchemaField::NOINDEX)) { + GlobalVectorIndexRegistry::Instance().RemoveVectorIndex(idx_name, field_info.short_name); + } + } + } + DCHECK(num_deleted == 0u || num_deleted == shard_set->size()); if (num_deleted == 0u) return cmd_cntx.rb->SendError(IndexNotFoundMsg(idx_name)); @@ -1322,24 +1355,45 @@ void SearchFamily::FtSearch(CmdArgList args, const CommandContext& cmd_cntx) { if (!search_algo.Init(query_str, ¶ms->query_params, ¶ms->optional_filters)) return builder->SendError("Query syntax error"); + std::unique_ptr knn_node = nullptr; + search::AstKnnNode* knn = nullptr; + + if (search_algo.IsKnnQuery()) { + knn_node = search_algo.GetKnnNode(); + knn = std::get_if(knn_node.get()); + } + // Because our coordinator thread may not have a shard, we can't check ahead if the index exists. atomic index_not_found{false}; vector docs(shard_set->size()); - cmd_cntx.tx->ScheduleSingleHop([&](Transaction* t, EngineShard* es) { - if (auto* index = es->search_indices()->GetIndex(index_name); index) - docs[es->shard_id()] = index->Search(t->GetOpArgs(es), *params, &search_algo); - else - index_not_found.store(true, memory_order_relaxed); - return OpStatus::OK; - }); + if (!knn || (knn && knn->Filter())) { + cmd_cntx.tx->ScheduleSingleHop([&](Transaction* t, EngineShard* es) { + if (auto* index = es->search_indices()->GetIndex(index_name); index) + docs[es->shard_id()] = index->Search(t->GetOpArgs(es), *params, &search_algo); + else + index_not_found.store(true, memory_order_relaxed); + return OpStatus::OK; + }); - if (index_not_found.load()) - return builder->SendError(string{index_name} + ": no such index"); + if (index_not_found.load()) + return builder->SendError(string{index_name} + ": no such index"); + + for (const auto& res : docs) { + if (res.error) + return builder->SendError(*res.error); + } + } + + if (knn_node) { + auto vector_index = + GlobalVectorIndexRegistry::Instance().GetVectorIndex(index_name, knn->field); + + if (!vector_index) { + return builder->SendError(string{index_name} + ": no such index"); + } - for (const auto& res : docs) { - if (res.error) - return builder->SendError(*res.error); + docs = vector_index->Search(knn, search_algo.GetKnnScoreSortOption(), docs, *params, cmd_cntx); } SearchReply(*params, search_algo.GetKnnScoreSortOption(), absl::MakeSpan(docs), builder); diff --git a/src/server/search/search_family_test.cc b/src/server/search/search_family_test.cc index 5fa7acce796a..ea982dc8e186 100644 --- a/src/server/search/search_family_test.cc +++ b/src/server/search/search_family_test.cc @@ -2446,10 +2446,6 @@ TEST_F(SearchFamilyTest, VectorIndexOperations) { // Basic star search auto star_search = Run({"ft.search", "vector_idx", "*"}); EXPECT_THAT(star_search, AreDocIds("vec:1", "vec:2", "vec:3", "vec:4", "vec:5")); - - // Search by vector field presence - auto vec_field_search = Run({"ft.search", "vector_idx", "@vec:*"}); - EXPECT_THAT(vec_field_search, AreDocIds("vec:1", "vec:2", "vec:3", "vec:4", "vec:5")); } // Test to verify that @field:* syntax works with sortable fields