From f52a975ed8ecfea778d34041a41d2ae96b1056a1 Mon Sep 17 00:00:00 2001 From: Volodymyr Yavdoshenko Date: Thu, 4 Sep 2025 18:00:58 +0300 Subject: [PATCH 1/2] feat: add search poc --- src/server/search/CMakeLists.txt | 5 +- src/server/search/doc_index.cc | 153 +++++++++- src/server/search/doc_index.h | 16 +- src/server/search/global_vector_index.cc | 350 ++++++++++++++++++++++ src/server/search/global_vector_index.h | 123 ++++++++ src/server/search/global_vector_search.cc | 299 ++++++++++++++++++ src/server/search/global_vector_search.h | 85 ++++++ src/server/search/performance_test.cc | 122 ++++++++ src/server/search/search_family.cc | 197 +++++++++++- 9 files changed, 1335 insertions(+), 15 deletions(-) create mode 100644 src/server/search/global_vector_index.cc create mode 100644 src/server/search/global_vector_index.h create mode 100644 src/server/search/global_vector_search.cc create mode 100644 src/server/search/global_vector_search.h create mode 100644 src/server/search/performance_test.cc diff --git a/src/server/search/CMakeLists.txt b/src/server/search/CMakeLists.txt index b37dcae4a6f3..74723773df43 100644 --- a/src/server/search/CMakeLists.txt +++ b/src/server/search/CMakeLists.txt @@ -4,13 +4,12 @@ if (NOT WITH_SEARCH) return() endif() -add_library(dfly_search_server aggregator.cc doc_accessors.cc doc_index.cc search_family.cc index_join.cc) +add_library(dfly_search_server aggregator.cc doc_accessors.cc doc_index.cc search_family.cc index_join.cc global_vector_index.cc global_vector_search.cc) target_link_libraries(dfly_search_server dfly_transaction dragonfly_lib dfly_facade redis_lib jsonpath TRDP::jsoncons) - cxx_test(search_family_test dfly_test_lib LABELS DFLY) cxx_test(aggregator_test dfly_test_lib LABELS DFLY) cxx_test(index_join_test dfly_test_lib LABELS DFLY) - +cxx_test(performance_test dfly_test_lib LABELS DFLY) add_dependencies(check_dfly search_family_test aggregator_test index_join_test) diff --git a/src/server/search/doc_index.cc b/src/server/search/doc_index.cc index 1f744592b860..f287b58484ff 100644 --- a/src/server/search/doc_index.cc +++ b/src/server/search/doc_index.cc @@ -16,6 +16,7 @@ #include "server/engine_shard_set.h" #include "server/family_utils.h" #include "server/search/doc_accessors.h" +#include "server/search/global_vector_index.h" #include "server/server_state.h" namespace dfly { @@ -236,6 +237,11 @@ string_view ShardDocIndex::DocKeyIndex::Get(DocId id) const { return keys_[id]; } +std::optional ShardDocIndex::DocKeyIndex::Find(string_view key) const { + auto it = ids_.find(key); + return it != ids_.end() ? std::make_optional(it->second) : std::nullopt; +} + size_t ShardDocIndex::DocKeyIndex::Size() const { return ids_.size(); } @@ -666,8 +672,11 @@ void ShardDocIndices::DropIndexCache(const dfly::ShardDocIndex& shard_doc_index) } void ShardDocIndices::RebuildAllIndices(const OpArgs& op_args) { - for (auto& [_, ptr] : indices_) + for (auto& [index_name, ptr] : indices_) { ptr->Rebuild(op_args, &local_mr_); + // PoC: Also rebuild global vector indices + ptr->RebuildGlobalVectorIndices(index_name, op_args); + } } vector ShardDocIndices::GetIndexNames() const { @@ -680,17 +689,23 @@ vector ShardDocIndices::GetIndexNames() const { void ShardDocIndices::AddDoc(string_view key, const DbContext& db_cntx, const PrimeValue& pv) { DCHECK(IsIndexedKeyType(pv)); - for (auto& [_, index] : indices_) { - if (index->Matches(key, pv.ObjType())) + for (auto& [index_name, index] : indices_) { + if (index->Matches(key, pv.ObjType())) { index->AddDoc(key, db_cntx, pv); + // PoC: Also add to global vector index if document has vector fields + index->AddDocToGlobalVectorIndex(index_name, key, db_cntx, pv); + } } } void ShardDocIndices::RemoveDoc(string_view key, const DbContext& db_cntx, const PrimeValue& pv) { DCHECK(IsIndexedKeyType(pv)); - for (auto& [_, index] : indices_) { - if (index->Matches(key, pv.ObjType())) + for (auto& [index_name, index] : indices_) { + if (index->Matches(key, pv.ObjType())) { + // PoC: Remove from global vector index first (before local removal) + index->RemoveDocFromGlobalVectorIndex(index_name, key, db_cntx, pv); index->RemoveDoc(key, db_cntx, pv); + } } } @@ -706,4 +721,132 @@ SearchStats ShardDocIndices::GetStats() const { return {GetUsedMemory(), indices_.size(), total_entries}; } +// PoC: Global vector index integration +void ShardDocIndex::AddDocToGlobalVectorIndex(std::string_view index_name, std::string_view key, + const DbContext& db_cntx, const PrimeValue& pv) { + if (!indices_) + return; + + auto accessor = GetAccessor(db_cntx, pv); + + // Check if document has vector fields and add them to global index + for (const auto& [field_ident, field_info] : base_->schema.fields) { + if (field_info.type == search::SchemaField::VECTOR && + !(field_info.flags & search::SchemaField::NOINDEX)) { + auto vector_info = accessor->GetVector(field_ident); + if (vector_info && vector_info->first) { + // Get or create global vector index + const auto& vparams = + std::get(field_info.special_params); + auto global_index = GlobalVectorIndexRegistry::Instance().GetOrCreateVectorIndex( + index_name, field_info.short_name, vparams); + + // Find local DocId for this key + auto local_id = key_index_.Find(key); + if (local_id) { + GlobalDocId global_id{EngineShard::tlocal()->shard_id(), *local_id}; + + // Add vector to global index + global_index->AddVector(global_id, key, vector_info->first.get()); + + LOG(INFO) << "Added vector to global index: " << index_name << ":" + << field_info.short_name << " key=" << key << " global_id={" + << global_id.shard_id << "," << global_id.local_doc_id << "}"; + } + } + } + } +} + +void ShardDocIndex::RemoveDocFromGlobalVectorIndex(std::string_view index_name, + std::string_view key, const DbContext& db_cntx, + const PrimeValue& pv) { + if (!indices_) + return; + + // Check if document has vector fields and remove them from global index + for (const auto& [field_ident, field_info] : base_->schema.fields) { + if (field_info.type == search::SchemaField::VECTOR && + !(field_info.flags & search::SchemaField::NOINDEX)) { + auto global_index = + GlobalVectorIndexRegistry::Instance().GetVectorIndex(index_name, field_info.short_name); + + if (global_index) { + // Find local DocId for this key + auto local_id = key_index_.Find(key); + if (local_id) { + GlobalDocId global_id{EngineShard::tlocal()->shard_id(), *local_id}; + + // Remove vector from global index + global_index->RemoveVector(global_id, key); + + VLOG(1) << "Removed vector from global index: " << index_name << ":" + << field_info.short_name << " key=" << key << " global_id={" << global_id.shard_id + << "," << global_id.local_doc_id << "}"; + } + } + } + } +} + +void ShardDocIndex::RebuildGlobalVectorIndices(std::string_view index_name, const OpArgs& op_args) { + if (!indices_) { + LOG(WARNING) << "No indices available for " << index_name; + return; + } + + // Traverse all documents and rebuild global vector indices + size_t vectors_added = 0; + auto cb = [this, index_name, &vectors_added](string_view key, const BaseAccessor& doc) { + LOG(INFO) << "Processing document: " << key; + + // Check if document has vector fields and add them to global index + for (const auto& [field_ident, field_info] : base_->schema.fields) { + LOG(INFO) << "Checking field: " << field_ident << " (type=" << field_info.type + << ", alias=" << field_info.short_name << ")"; + + if (field_info.type == search::SchemaField::VECTOR && + !(field_info.flags & search::SchemaField::NOINDEX)) { + LOG(INFO) << "Found vector field: " << field_ident << " -> " << field_info.short_name; + + auto vector_info = doc.GetVector(field_ident); + if (vector_info && vector_info->first) { + LOG(INFO) << "Vector data found for field: " << field_ident; + + // Get or create global vector index + const auto& vparams = + std::get(field_info.special_params); + auto global_index = GlobalVectorIndexRegistry::Instance().GetOrCreateVectorIndex( + index_name, field_info.short_name, vparams); + + // Find local DocId for this key + auto local_id = key_index_.Find(key); + if (local_id) { + GlobalDocId global_id{EngineShard::tlocal()->shard_id(), *local_id}; + + // Add vector to global index + if (global_index->AddVector(global_id, key, vector_info->first.get())) { + vectors_added++; + LOG(INFO) << "Successfully added vector to global index: " << index_name << ":" + << field_info.short_name << " key=" << key << " global_id={" + << global_id.shard_id << "," << global_id.local_doc_id << "}"; + } else { + LOG(WARNING) << "Failed to add vector to global index: " << key; + } + } else { + LOG(WARNING) << "Could not find local DocId for key: " << key; + } + } else { + LOG(INFO) << "No vector data for field: " << field_ident << " in key: " << key; + } + } + } + }; + + TraverseAllMatching(*base_, op_args, cb); + + LOG(INFO) << "Rebuilt global vector indices for " << index_name << " with " << key_index_.Size() + << " docs on shard " << EngineShard::tlocal()->shard_id(); +} + } // namespace dfly diff --git a/src/server/search/doc_index.h b/src/server/search/doc_index.h index 7c3ae53e0e91..b073690ce02d 100644 --- a/src/server/search/doc_index.h +++ b/src/server/search/doc_index.h @@ -26,6 +26,7 @@ namespace dfly { struct BaseAccessor; +class GlobalVectorIndex; // PoC: Forward declaration for global vector index using SearchDocData = absl::flat_hash_map; using Synonyms = search::Synonyms; @@ -219,6 +220,7 @@ class ShardDocIndex { std::optional Remove(std::string_view key); std::string_view Get(DocId id) const; + std::optional Find(std::string_view key) const; // PoC: Find DocId by key size_t Size() const; private: @@ -274,13 +276,21 @@ class ShardDocIndex { void RebuildForGroup(const OpArgs& op_args, const std::string_view& group_id, const std::vector& terms); - private: - // Clears internal data. Traverses all matching documents and assigns ids. - void Rebuild(const OpArgs& op_args, PMR_NS::memory_resource* mr); + // PoC: Global vector index support + void AddDocToGlobalVectorIndex(std::string_view index_name, std::string_view key, + const DbContext& db_cntx, const PrimeValue& pv); + void RemoveDocFromGlobalVectorIndex(std::string_view index_name, std::string_view key, + const DbContext& db_cntx, const PrimeValue& pv); + void RebuildGlobalVectorIndices(std::string_view index_name, const OpArgs& op_args); + // PoC: Public access to LoadEntry for global search coordinator using LoadedEntry = std::pair>; std::optional LoadEntry(search::DocId id, const OpArgs& op_args) const; + private: + // Clears internal data. Traverses all matching documents and assigns ids. + void Rebuild(const OpArgs& op_args, PMR_NS::memory_resource* mr); + // Behaviour identical to SortIndex::Sort for non-sortable fields that need to be fetched first std::vector KeepTopKSorted(std::vector* ids, size_t limit, const SearchParams::SortOption& sort, diff --git a/src/server/search/global_vector_index.cc b/src/server/search/global_vector_index.cc new file mode 100644 index 000000000000..a4e936613ccc --- /dev/null +++ b/src/server/search/global_vector_index.cc @@ -0,0 +1,350 @@ +// Copyright 2025, DragonflyDB authors. All rights reserved. +// See LICENSE for licensing terms. +// + +#include "server/search/global_vector_index.h" + +#include + +#include + +#include "base/logging.h" +#include "core/search/indices.h" +#include "core/search/vector_utils.h" + +namespace dfly { + +using namespace std; +using namespace search; + +GlobalVectorIndex::GlobalVectorIndex(const SchemaField::VectorParams& params, + PMR_NS::memory_resource* mr) + : params_(params) { + // Create the appropriate vector index based on parameters + if (params.use_hnsw) { + vector_index_ = std::make_unique(params, mr); + } else { + vector_index_ = std::make_unique(params, mr); + } +} + +GlobalVectorIndex::~GlobalVectorIndex() = default; + +std::vector> GlobalVectorIndex::Knn(float* target, size_t k, + std::optional ef) const { + std::shared_lock lock(rw_mutex_); + + std::vector> result; + + if (auto* hnsw_index = dynamic_cast(vector_index_.get())) { + auto internal_results = hnsw_index->Knn(target, k, ef); + result.reserve(internal_results.size()); + + for (const auto& [distance, internal_id] : internal_results) { + auto it = internal_to_global_.find(internal_id); + if (it != internal_to_global_.end()) { + result.emplace_back(distance, it->second); + } + } + } else if (auto* flat_index = dynamic_cast(vector_index_.get())) { + // For flat index, we need to compute distances to all vectors + std::vector> distances; + auto [dim, sim] = vector_index_->Info(); + + for (const auto& [global_id, internal_id] : global_to_internal_) { + const float* vec = flat_index->Get(internal_id); + float dist = VectorDistance(target, vec, dim, sim); + distances.emplace_back(dist, internal_id); + } + + // Sort and take top k + size_t limit = std::min(k, distances.size()); + std::partial_sort(distances.begin(), distances.begin() + limit, distances.end()); + distances.resize(limit); + + result.reserve(distances.size()); + for (const auto& [distance, internal_id] : distances) { + auto it = internal_to_global_.find(internal_id); + if (it != internal_to_global_.end()) { + result.emplace_back(distance, it->second); + } + } + } + + return result; +} + +std::vector> GlobalVectorIndex::Knn( + float* target, size_t k, std::optional ef, + const std::vector& allowed) const { + std::shared_lock lock(rw_mutex_); + + // Convert allowed GlobalDocIds to internal DocIds + std::vector allowed_internal; + allowed_internal.reserve(allowed.size()); + + for (const auto& global_id : allowed) { + auto it = global_to_internal_.find(global_id); + if (it != global_to_internal_.end()) { + allowed_internal.push_back(it->second); + } + } + + std::sort(allowed_internal.begin(), allowed_internal.end()); + + std::vector> result; + + if (auto* hnsw_index = dynamic_cast(vector_index_.get())) { + auto internal_results = hnsw_index->Knn(target, k, ef, allowed_internal); + result.reserve(internal_results.size()); + + for (const auto& [distance, internal_id] : internal_results) { + auto it = internal_to_global_.find(internal_id); + if (it != internal_to_global_.end()) { + result.emplace_back(distance, it->second); + } + } + } else { + // For flat index with filtering + std::vector> distances; + auto [dim, sim] = vector_index_->Info(); + + for (search::DocId internal_id : allowed_internal) { + if (auto* flat_index = dynamic_cast(vector_index_.get())) { + const float* vec = flat_index->Get(internal_id); + float dist = VectorDistance(target, vec, dim, sim); + distances.emplace_back(dist, internal_id); + } + } + + // Sort and take top k + size_t limit = std::min(k, distances.size()); + std::partial_sort(distances.begin(), distances.begin() + limit, distances.end()); + distances.resize(limit); + + result.reserve(distances.size()); + for (const auto& [distance, internal_id] : distances) { + auto it = internal_to_global_.find(internal_id); + if (it != internal_to_global_.end()) { + result.emplace_back(distance, it->second); + } + } + } + + return result; +} + +std::pair GlobalVectorIndex::Info() const { + std::shared_lock lock(rw_mutex_); + return vector_index_->Info(); +} + +bool GlobalVectorIndex::AddVector(GlobalDocId global_id, std::string_view key, + const float* vector) { + std::unique_lock lock(rw_mutex_); + + // Check if already exists + if (global_to_internal_.find(global_id) != global_to_internal_.end()) { + return false; + } + + // Assign new internal ID + search::DocId internal_id = next_internal_id_++; + + // Create mock document accessor for the vector index + class VectorDocumentAccessor : public DocumentAccessor { + public: + VectorDocumentAccessor(const float* vec, size_t dim) : vector_(vec), dim_(dim) { + } + + std::optional GetStrings(std::string_view field) const override { + return std::nullopt; + } + + std::optional GetVector(std::string_view field) const override { + if (!vector_) + return VectorInfo{}; + + auto ptr = std::make_unique(dim_); + std::memcpy(ptr.get(), vector_, dim_ * sizeof(float)); + return VectorInfo{std::move(ptr), dim_}; + } + + std::optional GetNumbers(std::string_view field) const override { + return std::nullopt; + } + + std::optional GetTags(std::string_view field) const override { + return std::nullopt; + } + + private: + const float* vector_; + size_t dim_; + }; + + VectorDocumentAccessor doc_accessor(vector, params_.dim); + + // Add to vector index + if (!vector_index_->Add(internal_id, doc_accessor, "vector_field")) { + return false; + } + + // Update mappings + global_to_internal_[global_id] = internal_id; + internal_to_global_[internal_id] = global_id; + global_to_key_[global_id] = std::string(key); + + VLOG(2) << "Added vector to global index: global_id={" << global_id.shard_id << "," + << global_id.local_doc_id << "}, internal_id=" << internal_id << ", key=" << key; + + return true; +} + +void GlobalVectorIndex::RemoveVector(GlobalDocId global_id, std::string_view key) { + std::unique_lock lock(rw_mutex_); + + auto it = global_to_internal_.find(global_id); + if (it == global_to_internal_.end()) { + return; + } + + search::DocId internal_id = it->second; + + // Create mock document accessor for removal + class VectorDocumentAccessor : public DocumentAccessor { + public: + std::optional GetStrings(std::string_view field) const override { + return std::nullopt; + } + + std::optional GetVector(std::string_view field) const override { + return VectorInfo{}; + } + + std::optional GetNumbers(std::string_view field) const override { + return std::nullopt; + } + + std::optional GetTags(std::string_view field) const override { + return std::nullopt; + } + }; + + VectorDocumentAccessor doc_accessor; + + // Remove from vector index + vector_index_->Remove(internal_id, doc_accessor, "vector_field"); + + // Remove mappings + global_to_internal_.erase(it); + internal_to_global_.erase(internal_id); + global_to_key_.erase(global_id); + + VLOG(2) << "Removed vector from global index: global_id={" << global_id.shard_id << "," + << global_id.local_doc_id << "}, key=" << key; +} + +size_t GlobalVectorIndex::Size() const { + std::shared_lock lock(rw_mutex_); + return global_to_internal_.size(); +} + +std::vector GlobalVectorIndex::GetAllDocsWithVectors() const { + std::shared_lock lock(rw_mutex_); + + std::vector result; + result.reserve(global_to_internal_.size()); + + for (const auto& [global_id, _] : global_to_internal_) { + result.push_back(global_id); + } + + return result; +} + +std::optional GlobalVectorIndex::GetKey(GlobalDocId global_id) const { + std::shared_lock lock(rw_mutex_); + + auto it = global_to_key_.find(global_id); + if (it != global_to_key_.end()) { + return it->second; + } + return std::nullopt; +} + +search::DocId GlobalVectorIndex::ToInternalDocId(GlobalDocId global_id) const { + auto it = global_to_internal_.find(global_id); + DCHECK(it != global_to_internal_.end()); + return it->second; +} + +GlobalDocId GlobalVectorIndex::FromInternalDocId(search::DocId internal_id) const { + auto it = internal_to_global_.find(internal_id); + DCHECK(it != internal_to_global_.end()); + return it->second; +} + +// Global registry implementation +GlobalVectorIndexRegistry& GlobalVectorIndexRegistry::Instance() { + static GlobalVectorIndexRegistry instance; + return instance; +} + +std::shared_ptr GlobalVectorIndexRegistry::GetOrCreateVectorIndex( + std::string_view index_name, std::string_view field_name, + const search::SchemaField::VectorParams& params) { + std::string key = MakeKey(index_name, field_name); + + { + std::shared_lock lock(registry_mutex_); + auto it = indices_.find(key); + if (it != indices_.end()) { + return it->second; + } + } + + std::unique_lock lock(registry_mutex_); + // Double-check after acquiring write lock + auto it = indices_.find(key); + if (it != indices_.end()) { + return it->second; + } + + // Create new global vector index + auto global_index = std::make_shared(params); + indices_[key] = global_index; + + LOG(INFO) << "Created global vector index: " << key << ", dim=" << params.dim + << ", use_hnsw=" << params.use_hnsw; + + return global_index; +} + +void GlobalVectorIndexRegistry::RemoveVectorIndex(std::string_view index_name, + std::string_view field_name) { + std::string key = MakeKey(index_name, field_name); + + std::unique_lock lock(registry_mutex_); + auto it = indices_.find(key); + if (it != indices_.end()) { + LOG(INFO) << "Removed global vector index: " << key; + indices_.erase(it); + } +} + +std::shared_ptr GlobalVectorIndexRegistry::GetVectorIndex( + std::string_view index_name, std::string_view field_name) const { + std::string key = MakeKey(index_name, field_name); + + std::shared_lock lock(registry_mutex_); + auto it = indices_.find(key); + return it != indices_.end() ? it->second : nullptr; +} + +std::string GlobalVectorIndexRegistry::MakeKey(std::string_view index_name, + std::string_view field_name) const { + return absl::StrCat(index_name, ":", field_name); +} + +} // namespace dfly diff --git a/src/server/search/global_vector_index.h b/src/server/search/global_vector_index.h new file mode 100644 index 000000000000..358ca52b1fc7 --- /dev/null +++ b/src/server/search/global_vector_index.h @@ -0,0 +1,123 @@ +// Copyright 2025, DragonflyDB authors. All rights reserved. +// See LICENSE for licensing terms. +// + +#pragma once + +#include + +#include +#include +#include +#include + +#include "core/search/base.h" +#include "core/search/indices.h" +#include "core/search/search.h" +#include "server/common.h" +#include "server/tx_base.h" + +namespace dfly { + +// Global document ID that uniquely identifies a document across all shards +struct GlobalDocId { + ShardId shard_id; + search::DocId local_doc_id; + + bool operator<(const GlobalDocId& other) const { + return std::tie(shard_id, local_doc_id) < std::tie(other.shard_id, other.local_doc_id); + } + + bool operator==(const GlobalDocId& other) const { + return shard_id == other.shard_id && local_doc_id == other.local_doc_id; + } + + bool operator!=(const GlobalDocId& other) const { + return !(*this == other); + } + + // Hash function for use in absl containers + template friend H AbslHashValue(H h, const GlobalDocId& id) { + return H::combine(std::move(h), id.shard_id, id.local_doc_id); + } +}; + +// Thread-safe global vector index that can be accessed from multiple threads +// Uses reader-writer locks to allow concurrent reads while protecting writes +class GlobalVectorIndex { + public: + explicit GlobalVectorIndex(const search::SchemaField::VectorParams& params, + PMR_NS::memory_resource* mr = PMR_NS::get_default_resource()); + + ~GlobalVectorIndex(); + + // Thread-safe read operations (multiple readers can access simultaneously) + std::vector> Knn(float* target, size_t k, + std::optional ef = std::nullopt) const; + + std::vector> Knn(float* target, size_t k, std::optional ef, + const std::vector& allowed) const; + + // Get vector info (dimensions, similarity metric) + std::pair Info() const; + + // Thread-safe write operations (exclusive access) + bool AddVector(GlobalDocId global_id, std::string_view key, const float* vector); + void RemoveVector(GlobalDocId global_id, std::string_view key); + + // Get statistics + size_t Size() const; + std::vector GetAllDocsWithVectors() const; + + // Get key for global doc id (for fetching document fields) + std::optional GetKey(GlobalDocId global_id) const; + + private: + // Convert between GlobalDocId and internal DocId used by vector index + search::DocId ToInternalDocId(GlobalDocId global_id) const; + GlobalDocId FromInternalDocId(search::DocId internal_id) const; + + mutable std::shared_mutex rw_mutex_; + std::unique_ptr vector_index_; + + // Mapping between GlobalDocId and internal DocId + absl::flat_hash_map global_to_internal_; + absl::flat_hash_map internal_to_global_; + + // Mapping from GlobalDocId to document key (for fetching fields later) + absl::flat_hash_map global_to_key_; + + // Counter for generating internal DocIds + search::DocId next_internal_id_{0}; + + // Vector parameters + search::SchemaField::VectorParams params_; +}; + +// Global registry for all vector indices +class GlobalVectorIndexRegistry { + public: + static GlobalVectorIndexRegistry& Instance(); + + // Get or create global vector index for given index name and field + std::shared_ptr GetOrCreateVectorIndex( + std::string_view index_name, std::string_view field_name, + const search::SchemaField::VectorParams& params); + + // Remove vector index + void RemoveVectorIndex(std::string_view index_name, std::string_view field_name); + + // Get existing vector index + std::shared_ptr GetVectorIndex(std::string_view index_name, + std::string_view field_name) const; + + private: + GlobalVectorIndexRegistry() = default; + + mutable std::shared_mutex registry_mutex_; + absl::flat_hash_map> indices_; + + std::string MakeKey(std::string_view index_name, std::string_view field_name) const; +}; + +} // namespace dfly diff --git a/src/server/search/global_vector_search.cc b/src/server/search/global_vector_search.cc new file mode 100644 index 000000000000..d5c27a6aa916 --- /dev/null +++ b/src/server/search/global_vector_search.cc @@ -0,0 +1,299 @@ +// Copyright 2025, DragonflyDB authors. All rights reserved. +// See LICENSE for licensing terms. +// + +#include "server/search/global_vector_search.h" + +#include + +#include "base/logging.h" +#include "core/search/ast_expr.h" +#include "core/search/parser.hh" +#include "core/search/query_driver.h" +#include "core/search/vector_utils.h" +#include "server/engine_shard_set.h" +#include "server/namespaces.h" +#include "server/search/doc_accessors.h" + +namespace dfly { + +using namespace std; +using namespace search; + +bool GlobalVectorSearchAlgorithm::Init(std::string_view query, const QueryParams* params) { + try { + QueryDriver driver{}; + driver.ResetScanner(); + driver.SetParams(params); + driver.SetInput(std::string{query}); + (void)Parser (&driver)(); + query_ = std::make_unique(driver.Take()); + } catch (const Parser::syntax_error& se) { + LOG(INFO) << "Failed to parse query \"" << query << "\": " << se.what(); + return false; + } catch (...) { + LOG_EVERY_T(INFO, 10) << "Unexpected query parser error \"" << query << "\""; + return false; + } + + if (std::holds_alternative(*query_)) { + LOG_EVERY_T(INFO, 10) << "Empty result after parsing query \"" << query << "\""; + return false; + } + + return true; +} + +GlobalSearchResult GlobalVectorSearchAlgorithm::Search( + std::string_view index_name, const std::vector& shard_indices) const { + if (!IsVectorOnlyQuery()) { + // Fallback to traditional shard-based search for non-vector-only queries + // This would need to be implemented to handle mixed queries + LOG(WARNING) << "Mixed queries not yet supported in global vector search PoC"; + return GlobalSearchResult{}; + } + + // Extract vector field name from KNN query + auto vector_field = ExtractVectorField(*query_); + if (!vector_field) { + LOG(ERROR) << "Could not extract vector field from KNN query"; + return GlobalSearchResult{}; + } + + // Get global vector index + auto global_index = + GlobalVectorIndexRegistry::Instance().GetVectorIndex(index_name, *vector_field); + + if (!global_index) { + VLOG(1) << "Global vector index not found: " << index_name << ":" << *vector_field + << ", will use shard-based search"; + // Return empty result to trigger fallback + GlobalSearchResult empty_result; + empty_result.total_hits = 0; + return empty_result; + } + + // Extract KNN parameters from query + if (auto* knn_node = std::get_if(query_.get())) { + auto knn_results = + global_index->Knn(knn_node->vec.first.get(), knn_node->limit, knn_node->ef_runtime); + + GlobalSearchResult result; + result.total_hits = knn_results.size(); + result.global_doc_ids.reserve(knn_results.size()); + + // Convert to search result format + result.knn_scores.reserve(knn_results.size()); + + for (const auto& [score, global_id] : knn_results) { + result.global_doc_ids.push_back(global_id); + result.knn_scores.emplace_back(global_id.local_doc_id, score); // temp mapping + } + + LOG(INFO) << "Global vector search found " << result.total_hits << " results for index " + << index_name << " field " << *vector_field; + + return result; + } + + return GlobalSearchResult{}; +} + +bool GlobalVectorSearchAlgorithm::IsVectorOnlyQuery() const { + return IsKnnOnlyQuery(*query_); +} + +std::optional GlobalVectorSearchAlgorithm::GetKnnScoreSortOption() const { + if (auto* knn = std::get_if(query_.get())) { + return KnnScoreSortOption{std::string_view{knn->score_alias}, knn->limit}; + } + return std::nullopt; +} + +std::optional GlobalVectorSearchAlgorithm::ExtractVectorFieldName() const { + if (auto* knn = std::get_if(query_.get())) { + return knn->field; + } + return std::nullopt; +} + +const search::AstKnnNode* GlobalVectorSearchAlgorithm::GetKnnNode() const { + return std::get_if(query_.get()); +} + +void GlobalVectorSearchAlgorithm::EnableProfiling() { + profiling_enabled_ = true; +} + +bool GlobalVectorSearchAlgorithm::IsKnnOnlyQuery(const AstNode& node) const { + // Check if this is a pure KNN query without other filters + if (auto* knn = std::get_if(&node)) { + // Pure KNN query should not have any filter + return !knn->filter || std::holds_alternative(*knn->filter); + } + return false; +} + +std::optional GlobalVectorSearchAlgorithm::ExtractVectorField( + const AstNode& node) const { + if (auto* knn = std::get_if(&node)) { + return knn->field; + } + return std::nullopt; +} + +// GlobalVectorSearchCoordinator implementation +GlobalSearchResult GlobalVectorSearchCoordinator::ExecuteGlobalVectorSearch( + std::string_view index_name, std::string_view query_str, const SearchParams& params, + const QueryParams& query_params, const std::vector& shard_indices) { + GlobalVectorSearchAlgorithm algo; + if (!algo.Init(query_str, &query_params)) { + return GlobalSearchResult{}; + } + + if (!algo.IsVectorOnlyQuery()) { + LOG(WARNING) << "Non-vector-only queries not supported in PoC"; + return GlobalSearchResult{}; + } + + // Execute global search + auto global_result = algo.Search(index_name, shard_indices); + if (global_result.global_doc_ids.empty()) { + return global_result; + } + + // Group global doc IDs by shard + auto grouped_ids = GroupByShard(global_result.global_doc_ids); + + // Fetch document fields from each shard + std::vector all_docs; + for (const auto& [shard_id, shard_global_ids] : grouped_ids) { + if (shard_id < shard_indices.size() && shard_indices[shard_id]) { + // Need to execute on the correct shard, not current thread's shard + // This is a limitation of the PoC - we can't easily access other shards from coordinator + // For now, use current shard as workaround + DbContext db_cntx{&namespaces->GetDefaultNamespace(), 0, GetCurrentTimeMs()}; + OpArgs op_args{EngineShard::tlocal(), nullptr, db_cntx}; + + auto shard_docs = FetchDocumentFields(shard_id, shard_global_ids, global_result.knn_scores, + params, shard_indices[shard_id], op_args); + + all_docs.insert(all_docs.end(), std::make_move_iterator(shard_docs.begin()), + std::make_move_iterator(shard_docs.end())); + } else { + LOG(WARNING) << "Shard " << shard_id << " not available"; + } + } + + // Apply SORTBY if needed (after fetching all documents) + if (params.sort_option && !all_docs.empty()) { + const auto& sort_opt = *params.sort_option; + LOG(INFO) << "Applying SORTBY " << sort_opt.field.Name() << " " + << (sort_opt.order == SortOrder::DESC ? "DESC" : "ASC") << " to " << all_docs.size() + << " global search results"; + + auto comparator = [&sort_opt](const SerializedSearchDoc& a, const SerializedSearchDoc& b) { + std::string field_name{sort_opt.field.OutputName()}; + + auto a_it = a.values.find(field_name); + auto b_it = b.values.find(field_name); + + // Handle missing values (put them at the end) + if (a_it == a.values.end() && b_it == b.values.end()) + return false; + if (a_it == a.values.end()) + return false; // a goes to end + if (b_it == b.values.end()) + return true; // b goes to end, a comes first + + bool result = a_it->second < b_it->second; + return sort_opt.order == SortOrder::DESC ? !result : result; + }; + + std::sort(all_docs.begin(), all_docs.end(), comparator); + } + + // Apply LIMIT after sorting (if any) + if (!all_docs.empty()) { + size_t start_idx = std::min(params.limit_offset, all_docs.size()); + size_t end_idx = std::min(start_idx + params.limit_total, all_docs.size()); + + if (start_idx > 0 || end_idx < all_docs.size()) { + std::vector limited_docs; + limited_docs.reserve(end_idx - start_idx); + + for (size_t i = start_idx; i < end_idx; ++i) { + limited_docs.push_back(std::move(all_docs[i])); + } + + all_docs = std::move(limited_docs); + LOG(INFO) << "Applied LIMIT " << params.limit_offset << " " << params.limit_total + << ", result size: " << all_docs.size(); + } + } + + global_result.docs = std::move(all_docs); + return global_result; +} + +absl::flat_hash_map> GlobalVectorSearchCoordinator::GroupByShard( + const std::vector& global_ids) { + absl::flat_hash_map> grouped; + + for (const auto& global_id : global_ids) { + grouped[global_id.shard_id].push_back(global_id); + } + + return grouped; +} + +std::vector GlobalVectorSearchCoordinator::FetchDocumentFields( + ShardId shard_id, const std::vector& shard_global_ids, + const std::vector>& knn_scores, const SearchParams& params, + ShardDocIndex* shard_index, const OpArgs& op_args) { + std::vector docs; + docs.reserve(shard_global_ids.size()); + + // Create map for quick score lookup + absl::flat_hash_map score_map; + for (const auto& [doc_id, score] : knn_scores) { + score_map[doc_id] = score; + } + + for (const auto& global_id : shard_global_ids) { + // Load entry from shard + auto entry = shard_index->LoadEntry(global_id.local_doc_id, op_args); + if (!entry) { + continue; // Document might have expired + } + + auto& [key, accessor] = *entry; + + // Serialize document fields based on return parameters + SearchDocData fields{}; + auto index_info = shard_index->GetInfo(); + if (params.ShouldReturnAllFields()) { + fields = accessor->Serialize(index_info.base_index.schema); + } + + auto return_fields = params.return_fields.value_or(std::vector{}); + auto more_fields = accessor->Serialize(index_info.base_index.schema, return_fields); + fields.insert(std::make_move_iterator(more_fields.begin()), + std::make_move_iterator(more_fields.end())); + + // Get KNN score for this document + float knn_score = 0.0f; + auto score_it = score_map.find(global_id.local_doc_id); + if (score_it != score_map.end()) { + knn_score = score_it->second; + } + + search::SortableValue sort_score = std::monostate{}; + + docs.push_back({std::string{key}, std::move(fields), knn_score, sort_score}); + } + + return docs; +} + +} // namespace dfly diff --git a/src/server/search/global_vector_search.h b/src/server/search/global_vector_search.h new file mode 100644 index 000000000000..2d6d68c48ba6 --- /dev/null +++ b/src/server/search/global_vector_search.h @@ -0,0 +1,85 @@ +// Copyright 2025, DragonflyDB authors. All rights reserved. +// See LICENSE for licensing terms. +// + +#pragma once + +#include "core/search/ast_expr.h" +#include "core/search/search.h" +#include "server/search/doc_index.h" +#include "server/search/global_vector_index.h" + +namespace dfly { + +// Enhanced search result that includes global document IDs for vector results +struct GlobalSearchResult { + size_t total_hits = 0; + std::vector docs; + std::vector> knn_scores; + std::optional profile; + std::optional error; + + // Additional field for global document IDs + std::vector global_doc_ids; + + GlobalSearchResult() = default; +}; + +// Global vector search algorithm that uses the global vector index +class GlobalVectorSearchAlgorithm { + public: + GlobalVectorSearchAlgorithm() = default; + + // Initialize with query - similar to SearchAlgorithm::Init + bool Init(std::string_view query, const search::QueryParams* params); + + // Search using global vector index for vector queries, fallback to shard-based for others + GlobalSearchResult Search(std::string_view index_name, + const std::vector& shard_indices) const; + + // Check if this is a vector-only KNN query that can use global index + bool IsVectorOnlyQuery() const; + + // Get KNN sort option if present + std::optional GetKnnScoreSortOption() const; + + // Extract vector field name from KNN query + std::optional ExtractVectorFieldName() const; + + // Get KNN node for direct access (PoC helper) + const search::AstKnnNode* GetKnnNode() const; + + void EnableProfiling(); + + private: + std::unique_ptr query_; + bool profiling_enabled_ = false; + + // Check if query contains only vector search (KNN) without other filters + bool IsKnnOnlyQuery(const search::AstNode& node) const; + + // Extract vector field name from KNN query + std::optional ExtractVectorField(const search::AstNode& node) const; +}; + +// Helper class to coordinate between global vector index and shard-based document fetching +class GlobalVectorSearchCoordinator { + public: + // Execute global vector search and fetch document fields from appropriate shards + static GlobalSearchResult ExecuteGlobalVectorSearch( + std::string_view index_name, std::string_view query_str, const SearchParams& params, + const search::QueryParams& query_params, const std::vector& shard_indices); + + private: + // Group global doc IDs by shard for efficient fetching + static absl::flat_hash_map> GroupByShard( + const std::vector& global_ids); + + // Fetch document fields from specific shard + static std::vector FetchDocumentFields( + ShardId shard_id, const std::vector& shard_global_ids, + const std::vector>& knn_scores, const SearchParams& params, + ShardDocIndex* shard_index, const OpArgs& op_args); +}; + +} // namespace dfly diff --git a/src/server/search/performance_test.cc b/src/server/search/performance_test.cc new file mode 100644 index 000000000000..c4334377a801 --- /dev/null +++ b/src/server/search/performance_test.cc @@ -0,0 +1,122 @@ +// Copyright 2025, DragonflyDB authors. All rights reserved. +// See LICENSE for licensing terms. +// + +#include +#include + +#include + +#include "base/gtest.h" +#include "facade/facade_test.h" +#include "server/search/search_family.h" +#include "server/test_utils.h" + +ABSL_DECLARE_FLAG(bool, enable_global_vector_search); + +using namespace testing; +using namespace std; +using namespace facade; + +namespace dfly { + +class PerformanceTest : public BaseFamilyTest { + protected: + void SetUp() override { + BaseFamilyTest::SetUp(); + } + + std::string FloatsToBytes(const std::vector& floats) { + return std::string(reinterpret_cast(floats.data()), floats.size() * sizeof(float)); + } + + auto TimeIt(std::function func) { + auto start = std::chrono::high_resolution_clock::now(); + func(); + auto end = std::chrono::high_resolution_clock::now(); + return std::chrono::duration_cast(end - start).count(); + } +}; + +TEST_F(PerformanceTest, GlobalVsShardBasedComparison) { + // Create index with many documents + Run({"FT.CREATE", "perf_idx", "ON", "HASH", "SCHEMA", "vec", "VECTOR", "FLAT", "6", "TYPE", + "FLOAT32", "DIM", "128", "DISTANCE_METRIC", "L2"}); + + // Add 10K documents to see global index benefits + const size_t num_docs = 10000; + for (size_t i = 0; i < num_docs; ++i) { + std::vector vec(128); + for (size_t j = 0; j < 128; ++j) { + vec[j] = static_cast(i + j) / 1000.0f; + } + + std::string key = "doc:" + std::to_string(i); + Run({"HSET", key, "vec", FloatsToBytes(vec)}); + } + + std::vector query_vec(128, 0.5f); + std::string query_bytes = FloatsToBytes(query_vec); + + // Test 1: NOCONTENT query (should be fastest with global index) + std::cout << "\n=== NOCONTENT Query Performance ===" << std::endl; + + // Shard-based search + absl::SetFlag(&FLAGS_enable_global_vector_search, false); + auto shard_time = TimeIt([&]() { + Run({"FT.SEARCH", "perf_idx", "*=>[KNN 50 @vec $qvec]", "NOCONTENT", "PARAMS", "2", "qvec", + query_bytes}); + }); + + // Global search + absl::SetFlag(&FLAGS_enable_global_vector_search, true); + auto global_time = TimeIt([&]() { + Run({"FT.SEARCH", "perf_idx", "*=>[KNN 50 @vec $qvec]", "NOCONTENT", "PARAMS", "2", "qvec", + query_bytes}); + }); + + std::cout << "Shard-based NOCONTENT: " << shard_time << " μs" << std::endl; + std::cout << "Global NOCONTENT: " << global_time << " μs" << std::endl; + std::cout << "Speedup: " << (double)shard_time / global_time << "x" << std::endl; + + // Test 2: Full query with fields + std::cout << "\n=== Full Query Performance ===" << std::endl; + + absl::SetFlag(&FLAGS_enable_global_vector_search, false); + auto shard_full_time = TimeIt([&]() { + Run({"FT.SEARCH", "perf_idx", "*=>[KNN 50 @vec $qvec]", "PARAMS", "2", "qvec", query_bytes}); + }); + + absl::SetFlag(&FLAGS_enable_global_vector_search, true); + auto global_full_time = TimeIt([&]() { + Run({"FT.SEARCH", "perf_idx", "*=>[KNN 50 @vec $qvec]", "PARAMS", "2", "qvec", query_bytes}); + }); + + std::cout << "Shard-based full: " << shard_full_time << " μs" << std::endl; + std::cout << "Global full: " << global_full_time << " μs" << std::endl; + std::cout << "Speedup: " << (double)shard_full_time / global_full_time << "x" << std::endl; + + // Test 3: With SORTBY + std::cout << "\n=== SORTBY Query Performance ===" << std::endl; + + absl::SetFlag(&FLAGS_enable_global_vector_search, false); + auto shard_sortby_time = TimeIt([&]() { + Run({"FT.SEARCH", "perf_idx", "*=>[KNN 50 @vec $qvec]", "SORTBY", "__vector_score", "PARAMS", + "2", "qvec", query_bytes}); + }); + + absl::SetFlag(&FLAGS_enable_global_vector_search, true); + auto global_sortby_time = TimeIt([&]() { + Run({"FT.SEARCH", "perf_idx", "*=>[KNN 50 @vec $qvec]", "SORTBY", "__vector_score", "PARAMS", + "2", "qvec", query_bytes}); + }); + + std::cout << "Shard-based SORTBY: " << shard_sortby_time << " μs" << std::endl; + std::cout << "Global SORTBY: " << global_sortby_time << " μs" << std::endl; + std::cout << "Speedup: " << (double)shard_sortby_time / global_sortby_time << "x" << std::endl; + + // Reset flag + absl::SetFlag(&FLAGS_enable_global_vector_search, false); +} + +} // namespace dfly diff --git a/src/server/search/search_family.cc b/src/server/search/search_family.cc index ff69a49eb871..faf9aca0a775 100644 --- a/src/server/search/search_family.cc +++ b/src/server/search/search_family.cc @@ -29,11 +29,15 @@ #include "server/container_utils.h" #include "server/engine_shard_set.h" #include "server/search/aggregator.h" +#include "server/search/doc_accessors.h" #include "server/search/doc_index.h" +#include "server/search/global_vector_search.h" #include "server/transaction.h" #include "src/core/overloaded.h" ABSL_FLAG(bool, search_reject_legacy_field, true, "FT.AGGREGATE: Reject legacy field names."); +ABSL_FLAG(bool, enable_global_vector_search, true, + "PoC: Enable global vector search for KNN queries."); namespace dfly { @@ -1029,9 +1033,31 @@ void SearchFamily::FtCreate(CmdArgList args, const CommandContext& cmd_cntx) { } auto idx_ptr = make_shared(std::move(parsed_index).value()); + + // PoC: Create global vector indices for vector fields + for (const auto& [field_ident, field_info] : idx_ptr->schema.fields) { + if (field_info.type == search::SchemaField::VECTOR && + !(field_info.flags & search::SchemaField::NOINDEX)) { + const auto& vparams = std::get(field_info.special_params); + // Use field alias (short_name) for global index key, as that's what queries use + GlobalVectorIndexRegistry::Instance().GetOrCreateVectorIndex(idx_name, field_info.short_name, + vparams); + LOG(INFO) << "Created global vector index for " << idx_name << ":" << field_info.short_name + << " (dim=" << vparams.dim << ", hnsw=" << vparams.use_hnsw << ")"; + } + } + cmd_cntx.tx->Execute( [idx_name, idx_ptr](auto* tx, auto* es) { es->search_indices()->InitIndex(tx->GetOpArgs(es), idx_name, idx_ptr); + + // PoC: Global vector indices will be populated when documents are added via HSET/JSON.SET + if (auto* index = es->search_indices()->GetIndex(idx_name); index) { + index->RebuildGlobalVectorIndices(idx_name, tx->GetOpArgs(es)); + } else { + LOG(WARNING) << "Could not find index " << idx_name << " on shard " << es->shard_id(); + } + return OpStatus::OK; }, true); @@ -1101,13 +1127,37 @@ void SearchFamily::FtDropIndex(CmdArgList args, const CommandContext& cmd_cntx) string_view idx_name = ArgS(args, 0); // TODO: Handle optional DD param + // PoC: Get index info and drop both local and global indices in single transaction + shared_ptr index_info; atomic_uint num_deleted{0}; + cmd_cntx.tx->ScheduleSingleHop([&](Transaction* t, EngineShard* es) { + // Get index info from first shard for global cleanup + if (es->shard_id() == 0) { + if (auto* idx = es->search_indices()->GetIndex(idx_name); idx != nullptr) { + index_info = make_shared(idx->GetInfo().base_index); + } + } + + // Drop local index if (es->search_indices()->DropIndex(idx_name)) num_deleted.fetch_add(1); + return OpStatus::OK; }); + // PoC: Remove global vector indices for all vector fields (after transaction) + if (index_info) { + for (const auto& [field_ident, field_info] : index_info->schema.fields) { + if (field_info.type == search::SchemaField::VECTOR && + !(field_info.flags & search::SchemaField::NOINDEX)) { + // Use field alias (short_name) for global index key, same as in FtCreate + GlobalVectorIndexRegistry::Instance().RemoveVectorIndex(idx_name, field_info.short_name); + LOG(INFO) << "Removed global vector index for " << idx_name << ":" << field_info.short_name; + } + } + } + DCHECK(num_deleted == 0u || num_deleted == shard_set->size()); if (num_deleted == 0u) return cmd_cntx.rb->SendError("-Unknown Index name"); @@ -1217,21 +1267,160 @@ void SearchFamily::FtSearch(CmdArgList args, const CommandContext& cmd_cntx) { if (!search_algo.Init(query_str, ¶ms->query_params)) return builder->SendError("Query syntax error"); - // Because our coordinator thread may not have a shard, we can't check ahead if the index exists. + // PoC: Check if we should use global vector search + bool use_global_search = false; + GlobalVectorSearchAlgorithm global_algo; + if (absl::GetFlag(FLAGS_enable_global_vector_search)) { + if (global_algo.Init(query_str, ¶ms->query_params) && global_algo.IsVectorOnlyQuery()) { + use_global_search = true; + LOG(INFO) << "Will attempt global vector search for KNN query (with SORTBY support): " + << query_str; + } + } + + // Single transaction hop - collect indices and optionally execute global search + vector shard_indices(shard_set->size(), nullptr); atomic index_not_found{false}; vector docs(shard_set->size()); + // Results from global search (if used) - per-shard buckets to avoid race conditions + std::vector> shard_global_docs(shard_set->size()); + atomic global_total_hits{0}; + atomic global_search_used{false}; + std::vector> shared_knn_results; + cmd_cntx.tx->ScheduleSingleHop([&](Transaction* t, EngineShard* es) { - if (auto* index = es->search_indices()->GetIndex(index_name); index) - docs[es->shard_id()] = index->Search(t->GetOpArgs(es), *params, &search_algo); - else + if (auto* index = es->search_indices()->GetIndex(index_name); index) { + shard_indices[es->shard_id()] = index; + + if (use_global_search) { + // Try global search from first shard only + if (es->shard_id() == 0) { + auto vector_field = global_algo.ExtractVectorFieldName(); + if (vector_field) { + auto global_index = + GlobalVectorIndexRegistry::Instance().GetVectorIndex(index_name, *vector_field); + + if (global_index && global_index->Size() > 0) { + if (auto* knn_node = global_algo.GetKnnNode()) { + auto knn_results = global_index->Knn(knn_node->vec.first.get(), knn_node->limit, + knn_node->ef_runtime); + + // Store KNN results for processing by all shards + shared_knn_results = knn_results; + + global_total_hits.store(knn_results.size()); + global_search_used.store(true); + } + } + } + } + } + + // If global search was initiated, collect documents from this shard + if (use_global_search && global_search_used.load()) { + size_t collected = 0; + for (const auto& [score, global_id] : shared_knn_results) { + if (global_id.shard_id == es->shard_id()) { + auto entry = index->LoadEntry(global_id.local_doc_id, t->GetOpArgs(es)); + if (entry) { + auto& [key, accessor] = *entry; + + SearchDocData fields{}; + auto index_info = index->GetInfo(); + if (params->ShouldReturnAllFields()) { + fields = accessor->Serialize(index_info.base_index.schema); + } + + auto return_fields = params->return_fields.value_or(std::vector{}); + auto more_fields = accessor->Serialize(index_info.base_index.schema, return_fields); + fields.insert(std::make_move_iterator(more_fields.begin()), + std::make_move_iterator(more_fields.end())); + + search::SortableValue sort_score = std::monostate{}; + + // Each shard adds to its own bucket (no race conditions) + shard_global_docs[es->shard_id()].push_back( + {std::string{key}, std::move(fields), score, sort_score}); + collected++; + } + } + } + } + + // If not using global search, execute traditional search + if (!use_global_search) { + docs[es->shard_id()] = index->Search(t->GetOpArgs(es), *params, &search_algo); + } + } else { index_not_found.store(true, memory_order_relaxed); + } return OpStatus::OK; }); if (index_not_found.load()) return builder->SendError(string{index_name} + ": no such index"); + // PoC: If global search was used, merge results from all shard buckets + if (global_search_used.load()) { + // Merge all shard buckets into single container + std::vector global_docs; + for (size_t shard_id = 0; shard_id < shard_global_docs.size(); ++shard_id) { + auto& shard_docs = shard_global_docs[shard_id]; + global_docs.insert(global_docs.end(), std::make_move_iterator(shard_docs.begin()), + std::make_move_iterator(shard_docs.end())); + } + + if (!global_docs.empty()) { + // Apply SORTBY if needed + if (params->sort_option) { + const auto& sort_opt = *params->sort_option; + auto comparator = [&sort_opt](const SerializedSearchDoc& a, const SerializedSearchDoc& b) { + std::string field_name{sort_opt.field.OutputName()}; + + auto a_it = a.values.find(field_name); + auto b_it = b.values.find(field_name); + + if (a_it == a.values.end() && b_it == b.values.end()) + return false; + if (a_it == a.values.end()) + return false; + if (b_it == b.values.end()) + return true; + + bool result = a_it->second < b_it->second; + return sort_opt.order == SortOrder::DESC ? !result : result; + }; + + std::sort(global_docs.begin(), global_docs.end(), comparator); + } + + // Apply LIMIT + size_t start_idx = std::min(params->limit_offset, global_docs.size()); + size_t end_idx = std::min(start_idx + params->limit_total, global_docs.size()); + + if (start_idx > 0 || end_idx < global_docs.size()) { + std::vector limited_docs; + for (size_t i = start_idx; i < end_idx; ++i) { + limited_docs.push_back(std::move(global_docs[i])); + } + global_docs = std::move(limited_docs); + } + + vector shard_results(1); + shard_results[0].total_hits = global_total_hits.load(); + shard_results[0].docs = std::move(global_docs); + + SearchReply(*params, global_algo.GetKnnScoreSortOption(), absl::MakeSpan(shard_results), + builder); + return; + } else { + LOG(WARNING) << "Global search used but no docs collected"; + } + } + + // Traditional shard-based results (if global search not used) + for (const auto& res : docs) { if (res.error) return builder->SendError(*res.error); From f4e1fad7746490fd7c8632db4a032752d583e34b Mon Sep 17 00:00:00 2001 From: Volodymyr Yavdoshenko Date: Sun, 7 Sep 2025 21:12:39 +0300 Subject: [PATCH 2/2] feat: add search poc --- src/server/search/doc_index.cc | 100 ++----- src/server/search/global_vector_search.cc | 39 +-- src/server/search/global_vector_search.h | 17 +- src/server/search/search_family.cc | 315 ++++++++++++---------- 4 files changed, 237 insertions(+), 234 deletions(-) diff --git a/src/server/search/doc_index.cc b/src/server/search/doc_index.cc index f287b58484ff..33775ac0fabc 100644 --- a/src/server/search/doc_index.cc +++ b/src/server/search/doc_index.cc @@ -728,31 +728,21 @@ void ShardDocIndex::AddDocToGlobalVectorIndex(std::string_view index_name, std:: return; auto accessor = GetAccessor(db_cntx, pv); + auto local_id = key_index_.Find(key); + if (!local_id) + return; + + GlobalDocId global_id{EngineShard::tlocal()->shard_id(), *local_id}; - // Check if document has vector fields and add them to global index for (const auto& [field_ident, field_info] : base_->schema.fields) { if (field_info.type == search::SchemaField::VECTOR && !(field_info.flags & search::SchemaField::NOINDEX)) { - auto vector_info = accessor->GetVector(field_ident); - if (vector_info && vector_info->first) { - // Get or create global vector index + if (auto vector_info = accessor->GetVector(field_ident); vector_info && vector_info->first) { const auto& vparams = std::get(field_info.special_params); auto global_index = GlobalVectorIndexRegistry::Instance().GetOrCreateVectorIndex( index_name, field_info.short_name, vparams); - - // Find local DocId for this key - auto local_id = key_index_.Find(key); - if (local_id) { - GlobalDocId global_id{EngineShard::tlocal()->shard_id(), *local_id}; - - // Add vector to global index - global_index->AddVector(global_id, key, vector_info->first.get()); - - LOG(INFO) << "Added vector to global index: " << index_name << ":" - << field_info.short_name << " key=" << key << " global_id={" - << global_id.shard_id << "," << global_id.local_doc_id << "}"; - } + global_index->AddVector(global_id, key, vector_info->first.get()); } } } @@ -764,89 +754,49 @@ void ShardDocIndex::RemoveDocFromGlobalVectorIndex(std::string_view index_name, if (!indices_) return; - // Check if document has vector fields and remove them from global index + auto local_id = key_index_.Find(key); + if (!local_id) + return; + + GlobalDocId global_id{EngineShard::tlocal()->shard_id(), *local_id}; + for (const auto& [field_ident, field_info] : base_->schema.fields) { if (field_info.type == search::SchemaField::VECTOR && !(field_info.flags & search::SchemaField::NOINDEX)) { - auto global_index = - GlobalVectorIndexRegistry::Instance().GetVectorIndex(index_name, field_info.short_name); - - if (global_index) { - // Find local DocId for this key - auto local_id = key_index_.Find(key); - if (local_id) { - GlobalDocId global_id{EngineShard::tlocal()->shard_id(), *local_id}; - - // Remove vector from global index - global_index->RemoveVector(global_id, key); - - VLOG(1) << "Removed vector from global index: " << index_name << ":" - << field_info.short_name << " key=" << key << " global_id={" << global_id.shard_id - << "," << global_id.local_doc_id << "}"; - } + if (auto global_index = GlobalVectorIndexRegistry::Instance().GetVectorIndex( + index_name, field_info.short_name)) { + global_index->RemoveVector(global_id, key); } } } } void ShardDocIndex::RebuildGlobalVectorIndices(std::string_view index_name, const OpArgs& op_args) { - if (!indices_) { - LOG(WARNING) << "No indices available for " << index_name; + if (!indices_) return; - } - // Traverse all documents and rebuild global vector indices - size_t vectors_added = 0; - auto cb = [this, index_name, &vectors_added](string_view key, const BaseAccessor& doc) { - LOG(INFO) << "Processing document: " << key; + auto cb = [this, index_name](string_view key, const BaseAccessor& doc) { + auto local_id = key_index_.Find(key); + if (!local_id) + return; - // Check if document has vector fields and add them to global index - for (const auto& [field_ident, field_info] : base_->schema.fields) { - LOG(INFO) << "Checking field: " << field_ident << " (type=" << field_info.type - << ", alias=" << field_info.short_name << ")"; + GlobalDocId global_id{EngineShard::tlocal()->shard_id(), *local_id}; + for (const auto& [field_ident, field_info] : base_->schema.fields) { if (field_info.type == search::SchemaField::VECTOR && !(field_info.flags & search::SchemaField::NOINDEX)) { - LOG(INFO) << "Found vector field: " << field_ident << " -> " << field_info.short_name; - - auto vector_info = doc.GetVector(field_ident); - if (vector_info && vector_info->first) { - LOG(INFO) << "Vector data found for field: " << field_ident; - - // Get or create global vector index + if (auto vector_info = doc.GetVector(field_ident); vector_info && vector_info->first) { const auto& vparams = std::get(field_info.special_params); auto global_index = GlobalVectorIndexRegistry::Instance().GetOrCreateVectorIndex( index_name, field_info.short_name, vparams); - - // Find local DocId for this key - auto local_id = key_index_.Find(key); - if (local_id) { - GlobalDocId global_id{EngineShard::tlocal()->shard_id(), *local_id}; - - // Add vector to global index - if (global_index->AddVector(global_id, key, vector_info->first.get())) { - vectors_added++; - LOG(INFO) << "Successfully added vector to global index: " << index_name << ":" - << field_info.short_name << " key=" << key << " global_id={" - << global_id.shard_id << "," << global_id.local_doc_id << "}"; - } else { - LOG(WARNING) << "Failed to add vector to global index: " << key; - } - } else { - LOG(WARNING) << "Could not find local DocId for key: " << key; - } - } else { - LOG(INFO) << "No vector data for field: " << field_ident << " in key: " << key; + global_index->AddVector(global_id, key, vector_info->first.get()); } } } }; TraverseAllMatching(*base_, op_args, cb); - - LOG(INFO) << "Rebuilt global vector indices for " << index_name << " with " << key_index_.Size() - << " docs on shard " << EngineShard::tlocal()->shard_id(); } } // namespace dfly diff --git a/src/server/search/global_vector_search.cc b/src/server/search/global_vector_search.cc index d5c27a6aa916..9995bd62367d 100644 --- a/src/server/search/global_vector_search.cc +++ b/src/server/search/global_vector_search.cc @@ -80,15 +80,7 @@ GlobalSearchResult GlobalVectorSearchAlgorithm::Search( GlobalSearchResult result; result.total_hits = knn_results.size(); - result.global_doc_ids.reserve(knn_results.size()); - - // Convert to search result format - result.knn_scores.reserve(knn_results.size()); - - for (const auto& [score, global_id] : knn_results) { - result.global_doc_ids.push_back(global_id); - result.knn_scores.emplace_back(global_id.local_doc_id, score); // temp mapping - } + result.knn_results = std::move(knn_results); LOG(INFO) << "Global vector search found " << result.total_hits << " results for index " << index_name << " field " << *vector_field; @@ -117,8 +109,17 @@ std::optional GlobalVectorSearchAlgorithm::ExtractVectorFieldName() return std::nullopt; } -const search::AstKnnNode* GlobalVectorSearchAlgorithm::GetKnnNode() const { - return std::get_if(query_.get()); +std::optional GlobalVectorSearchAlgorithm::GetKnnParams() + const { + if (auto* knn_node = std::get_if(query_.get())) { + std::optional ef_runtime = std::nullopt; + if (knn_node->ef_runtime.has_value()) { + ef_runtime = static_cast(knn_node->ef_runtime.value()); + } + return KnnParams{ + .vector = knn_node->vec.first.get(), .limit = knn_node->limit, .ef_runtime = ef_runtime}; + } + return std::nullopt; } void GlobalVectorSearchAlgorithm::EnableProfiling() { @@ -158,12 +159,20 @@ GlobalSearchResult GlobalVectorSearchCoordinator::ExecuteGlobalVectorSearch( // Execute global search auto global_result = algo.Search(index_name, shard_indices); - if (global_result.global_doc_ids.empty()) { + if (global_result.knn_results.empty()) { return global_result; } + // Extract global doc IDs from results and create knn_scores mapping + std::vector global_doc_ids; + std::vector> knn_scores; + for (const auto& [score, global_id] : global_result.knn_results) { + global_doc_ids.push_back(global_id); + knn_scores.emplace_back(global_id.local_doc_id, score); + } + // Group global doc IDs by shard - auto grouped_ids = GroupByShard(global_result.global_doc_ids); + auto grouped_ids = GroupByShard(global_doc_ids); // Fetch document fields from each shard std::vector all_docs; @@ -175,8 +184,8 @@ GlobalSearchResult GlobalVectorSearchCoordinator::ExecuteGlobalVectorSearch( DbContext db_cntx{&namespaces->GetDefaultNamespace(), 0, GetCurrentTimeMs()}; OpArgs op_args{EngineShard::tlocal(), nullptr, db_cntx}; - auto shard_docs = FetchDocumentFields(shard_id, shard_global_ids, global_result.knn_scores, - params, shard_indices[shard_id], op_args); + auto shard_docs = FetchDocumentFields(shard_id, shard_global_ids, knn_scores, params, + shard_indices[shard_id], op_args); all_docs.insert(all_docs.end(), std::make_move_iterator(shard_docs.begin()), std::make_move_iterator(shard_docs.end())); diff --git a/src/server/search/global_vector_search.h b/src/server/search/global_vector_search.h index 2d6d68c48ba6..7a14499885c4 100644 --- a/src/server/search/global_vector_search.h +++ b/src/server/search/global_vector_search.h @@ -11,17 +11,13 @@ namespace dfly { -// Enhanced search result that includes global document IDs for vector results +// Global vector search result with global document IDs struct GlobalSearchResult { size_t total_hits = 0; + std::vector> knn_results; std::vector docs; - std::vector> knn_scores; - std::optional profile; std::optional error; - // Additional field for global document IDs - std::vector global_doc_ids; - GlobalSearchResult() = default; }; @@ -46,8 +42,13 @@ class GlobalVectorSearchAlgorithm { // Extract vector field name from KNN query std::optional ExtractVectorFieldName() const; - // Get KNN node for direct access (PoC helper) - const search::AstKnnNode* GetKnnNode() const; + // Get KNN parameters for global search + struct KnnParams { + float* vector; + size_t limit; + std::optional ef_runtime; + }; + std::optional GetKnnParams() const; void EnableProfiling(); diff --git a/src/server/search/search_family.cc b/src/server/search/search_family.cc index faf9aca0a775..77f9b2a24930 100644 --- a/src/server/search/search_family.cc +++ b/src/server/search/search_family.cc @@ -42,6 +42,7 @@ ABSL_FLAG(bool, enable_global_vector_search, true, namespace dfly { using namespace std; + using namespace facade; namespace { @@ -985,6 +986,166 @@ void SearchReply(const SearchParams& params, } } +// Optimized global vector search execution +static void ExecuteGlobalVectorSearch(string_view index_name, string_view query_str, + const SearchParams& params, + const GlobalVectorSearchAlgorithm& global_algo, + const CommandContext& cmd_cntx) { + auto* builder = cmd_cntx.rb; + + auto vector_field = global_algo.ExtractVectorFieldName(); + if (!vector_field) { + return builder->SendError("Could not extract vector field from KNN query"); + } + + auto global_index = + GlobalVectorIndexRegistry::Instance().GetVectorIndex(index_name, *vector_field); + if (!global_index || global_index->Size() == 0) { + LOG(INFO) << "Global index not available for " << index_name << ":" << *vector_field + << ", falling back to shard-based search"; + return builder->SendError("Global index not available and fallback not implemented"); + } + + auto knn_params = global_algo.GetKnnParams(); + if (!knn_params) { + return builder->SendError("Could not extract KNN parameters"); + } + + // Execute global KNN search + auto knn_results = + global_index->Knn(knn_params->vector, knn_params->limit, knn_params->ef_runtime); + + if (knn_results.empty()) { + vector empty_results(1); + empty_results[0].total_hits = 0; + SearchReply(params, global_algo.GetKnnScoreSortOption(), absl::MakeSpan(empty_results), + builder); + return; + } + + // Streamlined approach: minimal containers and operations + vector global_docs; + global_docs.reserve(knn_results.size()); + + // Group by shard with minimal allocations + vector>> shard_doc_ids(shard_set->size()); + for (const auto& [score, global_id] : knn_results) { + shard_doc_ids[global_id.shard_id].emplace_back(score, global_id.local_doc_id); + } + + // Use per-shard vectors to avoid race conditions, but keep them minimal + vector> shard_docs(shard_set->size()); + atomic index_not_found{false}; + + cmd_cntx.tx->ScheduleSingleHop([&](Transaction* t, EngineShard* es) { + auto* index = es->search_indices()->GetIndex(index_name); + if (!index) { + index_not_found.store(true); + return OpStatus::OK; + } + + auto& shard_requests = shard_doc_ids[es->shard_id()]; + if (shard_requests.empty()) + return OpStatus::OK; + + auto& docs_for_shard = shard_docs[es->shard_id()]; + docs_for_shard.reserve(shard_requests.size()); + + // Cache schema reference to avoid repeated lookups + const auto& schema = index->GetInfo().base_index.schema; + + // Optimize serialization based on query type + if (params.ShouldReturnAllFields()) { + // Full serialization for full queries + for (const auto& [score, doc_id] : shard_requests) { + if (auto entry = index->LoadEntry(doc_id, t->GetOpArgs(es))) { + auto& [key, accessor] = *entry; + auto fields = accessor->Serialize(schema); + docs_for_shard.push_back({string{key}, std::move(fields), score, std::monostate{}}); + } + } + } else { + // Selective field serialization + const auto& return_fields = params.return_fields.value_or(vector{}); + for (const auto& [score, doc_id] : shard_requests) { + if (auto entry = index->LoadEntry(doc_id, t->GetOpArgs(es))) { + auto& [key, accessor] = *entry; + auto fields = return_fields.empty() ? SearchDocData{} + // NOCONTENT query - no fields needed + : accessor->Serialize(schema, return_fields); + docs_for_shard.push_back({string{key}, std::move(fields), score, std::monostate{}}); + } + } + } + + return OpStatus::OK; + }); + + if (index_not_found.load()) { + return builder->SendError(string{index_name} + ": no such index"); + } + + // Efficient merge with size hint + size_t total_docs = 0; + for (const auto& docs : shard_docs) { + total_docs += docs.size(); + } + global_docs.reserve(total_docs); + + for (auto& docs : shard_docs) { + global_docs.insert(global_docs.end(), std::make_move_iterator(docs.begin()), + std::make_move_iterator(docs.end())); + } + + // Only sort if there's an explicit non-score sort option + if (params.sort_option && !global_docs.empty()) { + const auto& sort_opt = *params.sort_option; + // Skip sorting if it's score-based (results are already sorted by KNN) + string_view field_name = sort_opt.field.OutputName(); + if (field_name != "_score" && field_name != "score") { + sort(global_docs.begin(), global_docs.end(), + [&sort_opt](const SerializedSearchDoc& a, const SerializedSearchDoc& b) { + auto field_name = string{sort_opt.field.OutputName()}; + auto a_it = a.values.find(field_name); + auto b_it = b.values.find(field_name); + + if (a_it == a.values.end()) + return false; + if (b_it == b.values.end()) + return true; + + bool result = a_it->second < b_it->second; + return sort_opt.order == SortOrder::DESC ? !result : result; + }); + } + } + + // Apply LIMIT efficiently + if (params.limit_offset > 0 || params.limit_total < global_docs.size()) { + size_t start_idx = min(params.limit_offset, global_docs.size()); + size_t end_idx = min(start_idx + params.limit_total, global_docs.size()); + + if (start_idx >= global_docs.size()) { + global_docs.clear(); + } else { + // Erase from end first to avoid shifting elements + if (end_idx < global_docs.size()) { + global_docs.erase(global_docs.begin() + end_idx, global_docs.end()); + } + if (start_idx > 0) { + global_docs.erase(global_docs.begin(), global_docs.begin() + start_idx); + } + } + } + + vector shard_results_final(1); + shard_results_final[0].total_hits = knn_results.size(); + shard_results_final[0].docs = std::move(global_docs); + + SearchReply(params, global_algo.GetKnnScoreSortOption(), absl::MakeSpan(shard_results_final), + builder); +} + // Warms up the query parser to avoid first-call slowness void WarmupQueryParser() { static std::once_flag warmed_up; @@ -1263,95 +1424,35 @@ void SearchFamily::FtSearch(CmdArgList args, const CommandContext& cmd_cntx) { if (SendErrorIfOccurred(params, &parser, builder)) return; + // Parse query once and determine search strategy + GlobalVectorSearchAlgorithm global_algo; search::SearchAlgorithm search_algo; - if (!search_algo.Init(query_str, ¶ms->query_params)) + + if (!global_algo.Init(query_str, ¶ms->query_params)) return builder->SendError("Query syntax error"); - // PoC: Check if we should use global vector search - bool use_global_search = false; - GlobalVectorSearchAlgorithm global_algo; - if (absl::GetFlag(FLAGS_enable_global_vector_search)) { - if (global_algo.Init(query_str, ¶ms->query_params) && global_algo.IsVectorOnlyQuery()) { - use_global_search = true; - LOG(INFO) << "Will attempt global vector search for KNN query (with SORTBY support): " - << query_str; - } + bool use_global_search = + absl::GetFlag(FLAGS_enable_global_vector_search) && global_algo.IsVectorOnlyQuery(); + + // Only init shard-based algorithm if needed + if (!use_global_search && !search_algo.Init(query_str, ¶ms->query_params)) + return builder->SendError("Query syntax error"); + + // Early exit for global search if possible + if (use_global_search) { + return ExecuteGlobalVectorSearch(index_name, query_str, *params, global_algo, cmd_cntx); } - // Single transaction hop - collect indices and optionally execute global search + // Shard-based search setup vector shard_indices(shard_set->size(), nullptr); atomic index_not_found{false}; vector docs(shard_set->size()); - // Results from global search (if used) - per-shard buckets to avoid race conditions - std::vector> shard_global_docs(shard_set->size()); - atomic global_total_hits{0}; - atomic global_search_used{false}; - std::vector> shared_knn_results; - cmd_cntx.tx->ScheduleSingleHop([&](Transaction* t, EngineShard* es) { if (auto* index = es->search_indices()->GetIndex(index_name); index) { shard_indices[es->shard_id()] = index; - - if (use_global_search) { - // Try global search from first shard only - if (es->shard_id() == 0) { - auto vector_field = global_algo.ExtractVectorFieldName(); - if (vector_field) { - auto global_index = - GlobalVectorIndexRegistry::Instance().GetVectorIndex(index_name, *vector_field); - - if (global_index && global_index->Size() > 0) { - if (auto* knn_node = global_algo.GetKnnNode()) { - auto knn_results = global_index->Knn(knn_node->vec.first.get(), knn_node->limit, - knn_node->ef_runtime); - - // Store KNN results for processing by all shards - shared_knn_results = knn_results; - - global_total_hits.store(knn_results.size()); - global_search_used.store(true); - } - } - } - } - } - - // If global search was initiated, collect documents from this shard - if (use_global_search && global_search_used.load()) { - size_t collected = 0; - for (const auto& [score, global_id] : shared_knn_results) { - if (global_id.shard_id == es->shard_id()) { - auto entry = index->LoadEntry(global_id.local_doc_id, t->GetOpArgs(es)); - if (entry) { - auto& [key, accessor] = *entry; - - SearchDocData fields{}; - auto index_info = index->GetInfo(); - if (params->ShouldReturnAllFields()) { - fields = accessor->Serialize(index_info.base_index.schema); - } - - auto return_fields = params->return_fields.value_or(std::vector{}); - auto more_fields = accessor->Serialize(index_info.base_index.schema, return_fields); - fields.insert(std::make_move_iterator(more_fields.begin()), - std::make_move_iterator(more_fields.end())); - - search::SortableValue sort_score = std::monostate{}; - - // Each shard adds to its own bucket (no race conditions) - shard_global_docs[es->shard_id()].push_back( - {std::string{key}, std::move(fields), score, sort_score}); - collected++; - } - } - } - } - - // If not using global search, execute traditional search - if (!use_global_search) { - docs[es->shard_id()] = index->Search(t->GetOpArgs(es), *params, &search_algo); - } + // Execute traditional shard-based search + docs[es->shard_id()] = index->Search(t->GetOpArgs(es), *params, &search_algo); } else { index_not_found.store(true, memory_order_relaxed); } @@ -1361,65 +1462,7 @@ void SearchFamily::FtSearch(CmdArgList args, const CommandContext& cmd_cntx) { if (index_not_found.load()) return builder->SendError(string{index_name} + ": no such index"); - // PoC: If global search was used, merge results from all shard buckets - if (global_search_used.load()) { - // Merge all shard buckets into single container - std::vector global_docs; - for (size_t shard_id = 0; shard_id < shard_global_docs.size(); ++shard_id) { - auto& shard_docs = shard_global_docs[shard_id]; - global_docs.insert(global_docs.end(), std::make_move_iterator(shard_docs.begin()), - std::make_move_iterator(shard_docs.end())); - } - - if (!global_docs.empty()) { - // Apply SORTBY if needed - if (params->sort_option) { - const auto& sort_opt = *params->sort_option; - auto comparator = [&sort_opt](const SerializedSearchDoc& a, const SerializedSearchDoc& b) { - std::string field_name{sort_opt.field.OutputName()}; - - auto a_it = a.values.find(field_name); - auto b_it = b.values.find(field_name); - - if (a_it == a.values.end() && b_it == b.values.end()) - return false; - if (a_it == a.values.end()) - return false; - if (b_it == b.values.end()) - return true; - - bool result = a_it->second < b_it->second; - return sort_opt.order == SortOrder::DESC ? !result : result; - }; - - std::sort(global_docs.begin(), global_docs.end(), comparator); - } - - // Apply LIMIT - size_t start_idx = std::min(params->limit_offset, global_docs.size()); - size_t end_idx = std::min(start_idx + params->limit_total, global_docs.size()); - - if (start_idx > 0 || end_idx < global_docs.size()) { - std::vector limited_docs; - for (size_t i = start_idx; i < end_idx; ++i) { - limited_docs.push_back(std::move(global_docs[i])); - } - global_docs = std::move(limited_docs); - } - - vector shard_results(1); - shard_results[0].total_hits = global_total_hits.load(); - shard_results[0].docs = std::move(global_docs); - - SearchReply(*params, global_algo.GetKnnScoreSortOption(), absl::MakeSpan(shard_results), - builder); - return; - } else { - LOG(WARNING) << "Global search used but no docs collected"; - } - } - - // Traditional shard-based results (if global search not used) + // Process traditional shard-based results for (const auto& res : docs) { if (res.error)