Skip to content

Commit 467f0fc

Browse files
committed
WIP 2
1 parent e65709a commit 467f0fc

File tree

7 files changed

+75
-23
lines changed

7 files changed

+75
-23
lines changed

src/core/search/indices.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -509,14 +509,15 @@ FlatVectorIndex::FlatVectorIndex(const SchemaField::VectorParams& params, ShardI
509509
DCHECK(!params.use_hnsw);
510510
entries_.resize(shard_set_size);
511511
for (size_t i = 0; i < shard_set_size; i++) {
512-
entries_[i].resize(params.capacity * params.dim);
512+
entries_[i].reserve(params.capacity * params.dim);
513513
}
514514
}
515515

516516
void FlatVectorIndex::AddVector(GlobalDocId id,
517517
const typename BaseVectorIndex<GlobalDocId>::VectorPtr& vector) {
518518
auto shard_id = search::GlobalDocIdShardId(id);
519519
auto shard_doc_id = search::GlobalDocIdLocalId(id);
520+
DCHECK_LE(shard_doc_id * BaseVectorIndex<GlobalDocId>::dim_, entries_[shard_id].size());
520521
if (shard_doc_id * BaseVectorIndex<GlobalDocId>::dim_ == entries_[shard_id].size()) {
521522
unique_lock<util::fb2::SharedMutex> lock{shard_vector_locks_[shard_id]};
522523
entries_[shard_id].resize((shard_doc_id + 1) * BaseVectorIndex<GlobalDocId>::dim_);

src/core/search/search.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ class SearchAlgorithm {
202202
SearchResult Search(const FieldIndices* index,
203203
size_t cuttoff_limit = std::numeric_limits<size_t>::max()) const;
204204

205-
std::optional<KnnScoreSortOption> GetKnnScoreSortOption() const {
205+
const std::optional<KnnScoreSortOption>& GetKnnScoreSortOption() const {
206206
return knn_score_sort_option_;
207207
}
208208

src/server/main_service.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ extern "C" {
5959
#include "server/multi_command_squasher.h"
6060
#include "server/namespaces.h"
6161
#include "server/script_mgr.h"
62+
#include "server/search/global_vector_index.h"
6263
#include "server/search/search_family.h"
6364
#include "server/server_state.h"
6465
#include "server/set_family.h"
@@ -1123,6 +1124,11 @@ void Service::Shutdown() {
11231124

11241125
shard_set->PreShutdown();
11251126
shard_set->Shutdown();
1127+
1128+
#ifdef WITH_SEARCH
1129+
GlobalVectorIndexRegistry::Instance().Reset();
1130+
#endif
1131+
11261132
Transaction::Shutdown();
11271133

11281134
pp_.AwaitFiberOnAll([](ProactorBase* pb) { ServerState::tlocal()->Destroy(); });

src/server/search/doc_index.cc

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include <memory>
1010
#include <optional>
1111
#include <queue>
12+
#include <string_view>
1213

1314
#include "absl/strings/str_cat.h"
1415
#include "base/logging.h"
@@ -382,11 +383,9 @@ void ShardDocIndex::AddDocToGlobalVectorIndex(std::string_view index_name,
382383
for (const auto& [field_ident, field_info] : base_->schema.fields) {
383384
if (field_info.type == search::SchemaField::VECTOR &&
384385
!(field_info.flags & search::SchemaField::NOINDEX)) {
385-
if (auto vector_info = accessor->GetVector(field_ident); vector_info && vector_info->first) {
386-
auto global_index =
387-
GlobalVectorIndexRegistry::Instance().GetVectorIndex(index_name, field_info.short_name);
388-
global_index->Add(global_id, *accessor, field_ident);
389-
}
386+
GlobalVectorIndexRegistry::Instance()
387+
.GetVectorIndex(index_name, field_info.short_name)
388+
->Add(global_id, *accessor, field_ident);
390389
}
391390
}
392391
}
@@ -411,18 +410,18 @@ void ShardDocIndex::RemoveDocFromGlobalVectorIndex(std::string_view index_name,
411410
void ShardDocIndex::RebuildGlobalVectorIndices(std::string_view index_name, const OpArgs& op_args) {
412411
auto cb = [this, index_name](string_view key, const BaseAccessor& doc) {
413412
auto local_id = key_index_.Find(key);
413+
414414
if (!local_id)
415415
return;
416+
416417
GlobalDocId global_id = search::CreateGlobalDocId(EngineShard::tlocal()->shard_id(), *local_id);
417418

418419
for (const auto& [field_ident, field_info] : base_->schema.fields) {
419420
if (field_info.type == search::SchemaField::VECTOR &&
420421
!(field_info.flags & search::SchemaField::NOINDEX)) {
421-
if (auto vector_info = doc.GetVector(field_ident); vector_info && vector_info->first) {
422-
auto global_index = GlobalVectorIndexRegistry::Instance().GetVectorIndex(
423-
index_name, field_info.short_name);
424-
global_index->Add(global_id, doc, field_ident);
425-
}
422+
GlobalVectorIndexRegistry::Instance()
423+
.GetVectorIndex(index_name, field_info.short_name)
424+
->Add(global_id, doc, field_ident);
426425
}
427426
}
428427
};

src/server/search/global_vector_index.cc

Lines changed: 47 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -66,10 +66,10 @@ std::vector<std::pair<float, search::GlobalDocId>> GlobalVectorIndex::SearchKnnF
6666
return index->Knn(knn->vec.first.get());
6767
}
6868

69-
std::vector<SearchResult> GlobalVectorIndex::Search(const search::AstKnnNode* knn_node,
70-
const std::vector<SearchResult>& filter_docs,
71-
const CommandContext& cmd_cntx,
72-
const SearchParams& params) {
69+
std::vector<SearchResult> GlobalVectorIndex::Search(
70+
const search::AstKnnNode* knn_node, const std::vector<SearchResult>& filter_docs,
71+
const std::optional<search::KnnScoreSortOption>& knn_score_option,
72+
const CommandContext& cmd_cntx, const SearchParams& params) {
7373
std::vector<SearchResult> results(1);
7474

7575
std::optional<std::vector<search::GlobalDocId>> filter_docs_global_ids = std::nullopt;
@@ -105,7 +105,8 @@ std::vector<SearchResult> GlobalVectorIndex::Search(const search::AstKnnNode* kn
105105
return results;
106106
}
107107

108-
std::vector<SerializedSearchDoc> knn_result_docs(knn_results.size());
108+
std::vector<SerializedSearchDoc> knn_result_docs;
109+
knn_result_docs.reserve(knn_results.size());
109110

110111
// Group by shard with minimal allocations
111112
std::vector<std::vector<std::pair<float, search::DocId>>> shard_doc_ids(shard_size);
@@ -148,28 +149,56 @@ std::vector<SearchResult> GlobalVectorIndex::Search(const search::AstKnnNode* kn
148149
// Cache schema reference to avoid repeated lookups
149150
const auto& schema = index->GetInfo().base_index.schema;
150151

152+
auto& sort_option = params.sort_option;
153+
bool need_fetch_sort_field = false;
154+
155+
if (sort_option) {
156+
need_fetch_sort_field = !params.sort_option->IsSame(*knn_score_option);
157+
}
158+
151159
// Optimize serialization based on query type
152160
if (params.ShouldReturnAllFields()) {
153161
// Full serialization for full queries
154162
for (const auto& [score, doc_id] : shard_requests) {
155163
if (auto entry = index->LoadEntry(doc_id, t->GetOpArgs(es))) {
156164
auto& [key, accessor] = *entry;
157165
auto fields = accessor->Serialize(schema);
166+
167+
search::SortableValue sort_score = std::monostate{};
168+
169+
if (need_fetch_sort_field) {
170+
sort_score = fields[params.sort_option->field.Name()];
171+
}
172+
158173
docs_for_shard.push_back(
159-
{doc_id, std::string{key}, std::move(fields), score, std::monostate{}});
174+
{doc_id, std::string{key}, std::move(fields), score, sort_score});
160175
}
161176
}
162177
} else {
163178
// Selective field serialization
164179
const auto& return_fields = params.return_fields.value_or(std::vector<FieldReference>{});
180+
bool sort_field_in_return_fields = false;
181+
182+
if (need_fetch_sort_field) {
183+
for (auto& field_reference : return_fields) {
184+
sort_field_in_return_fields = field_reference.Name() == sort_option->field.Name();
185+
}
186+
}
187+
165188
for (const auto& [score, doc_id] : shard_requests) {
166189
if (auto entry = index->LoadEntry(doc_id, t->GetOpArgs(es))) {
167190
auto& [key, accessor] = *entry;
168191
auto fields = return_fields.empty() ? SearchDocData{}
169192
// NOCONTENT query - no fields needed
170193
: accessor->Serialize(schema, return_fields);
194+
195+
search::SortableValue sort_score = std::monostate{};
196+
if (sort_field_in_return_fields && need_fetch_sort_field) {
197+
sort_score = fields[params.sort_option->field.Name()];
198+
}
199+
171200
docs_for_shard.push_back(
172-
{doc_id, std::string{key}, std::move(fields), score, std::monostate{}});
201+
{doc_id, std::string{key}, std::move(fields), score, sort_score});
173202
}
174203
}
175204
}
@@ -182,6 +211,12 @@ std::vector<SearchResult> GlobalVectorIndex::Search(const search::AstKnnNode* kn
182211
std::make_move_iterator(docs.end()));
183212
}
184213

214+
auto cb = [](SerializedSearchDoc& l, SerializedSearchDoc& r) {
215+
return l.knn_score > r.knn_score;
216+
};
217+
partial_sort(knn_result_docs.begin(), knn_result_docs.begin() + knn_result_docs.size() / 2,
218+
knn_result_docs.end(), cb);
219+
185220
results[0].total_hits = knn_results.size();
186221
results[0].docs = std::move(knn_result_docs);
187222
return results;
@@ -230,6 +265,11 @@ std::shared_ptr<GlobalVectorIndex> GlobalVectorIndexRegistry::GetVectorIndex(
230265
return it != indices_.end() ? it->second : nullptr;
231266
}
232267

268+
void GlobalVectorIndexRegistry::Reset() {
269+
std::unique_lock<std::shared_mutex> lock(registry_mutex_);
270+
indices_.clear();
271+
}
272+
233273
std::string GlobalVectorIndexRegistry::MakeKey(std::string_view index_name,
234274
std::string_view field_name) const {
235275
return absl::StrCat(index_name, ":", field_name);

src/server/search/global_vector_index.h

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121

2222
namespace dfly {
2323

24+
struct KnnScoreSortOption;
25+
2426
class GlobalVectorIndex {
2527
public:
2628
GlobalVectorIndex(const search::SchemaField::VectorParams& params, std::string_view index_name,
@@ -31,9 +33,10 @@ class GlobalVectorIndex {
3133
bool Add(search::GlobalDocId id, const search::DocumentAccessor& doc, std::string_view field);
3234
void Remove(search::GlobalDocId id, const search::DocumentAccessor& doc, std::string_view field);
3335

34-
std::vector<SearchResult> Search(const search::AstKnnNode* knn,
35-
const std::vector<SearchResult>& filter_docs,
36-
const CommandContext& cmd_cntx, const SearchParams& params);
36+
std::vector<SearchResult> Search(
37+
const search::AstKnnNode* knn, const std::vector<SearchResult>& filter_docs,
38+
const std::optional<search::KnnScoreSortOption>& knn_score_option,
39+
const CommandContext& cmd_cntx, const SearchParams& params);
3740

3841
std::pair<size_t, search::VectorSimilarity> Info() const;
3942

@@ -70,6 +73,9 @@ class GlobalVectorIndexRegistry {
7073
std::shared_ptr<GlobalVectorIndex> GetVectorIndex(std::string_view index_name,
7174
std::string_view field_name) const;
7275

76+
// Reset all vector indices
77+
void Reset();
78+
7379
private:
7480
GlobalVectorIndexRegistry() = default;
7581

src/server/search/search_family.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1394,7 +1394,7 @@ void SearchFamily::FtSearch(CmdArgList args, const CommandContext& cmd_cntx) {
13941394
if (!vector_index)
13951395
return builder->SendError(string{index_name} + ": no such index");
13961396

1397-
docs = vector_index->Search(knn, docs, cmd_cntx, *params);
1397+
docs = vector_index->Search(knn, docs, search_algo.GetKnnScoreSortOption(), cmd_cntx, *params);
13981398
}
13991399

14001400
SearchReply(*params, search_algo.GetKnnScoreSortOption(), absl::MakeSpan(docs), builder);

0 commit comments

Comments
 (0)