Skip to content

Commit cf80534

Browse files
committed
feat: add HNSW index deserialization
1 parent b7e5f93 commit cf80534

File tree

7 files changed

+294
-19
lines changed

7 files changed

+294
-19
lines changed

src/core/search/hnsw_index.cc

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,19 @@ struct HnswlibAdapter {
133133
return metadata;
134134
}
135135

136+
void SetMetadata(const HnswIndexMetadata& metadata) {
137+
MRMWMutexLock lock(&mrmw_mutex_, MRMWMutex::LockMode::kWriteLock);
138+
if (world_.max_elements_ < metadata.max_elements) {
139+
world_.resizeIndex(metadata.max_elements);
140+
}
141+
world_.cur_element_count.store(metadata.cur_element_count);
142+
world_.maxlevel_ = metadata.maxlevel;
143+
world_.enterpoint_node_ = metadata.enterpoint_node;
144+
world_.ef_construction_ = metadata.ef_construction;
145+
world_.mult_ = metadata.mult;
146+
// Note: M, maxM, maxM0 are set at construction and shouldn't change
147+
}
148+
136149
size_t GetNodeCount() const {
137150
MRMWMutexLock lock(&mrmw_mutex_, MRMWMutex::LockMode::kReadLock);
138151
return world_.cur_element_count.load();
@@ -208,6 +221,99 @@ struct HnswlibAdapter {
208221
return out;
209222
}
210223

224+
public:
225+
// Restore HNSW graph structure from serialized nodes with metadata
226+
void RestoreFromNodes(const std::vector<HnswNodeData>& nodes, const HnswIndexMetadata& metadata) {
227+
MRMWMutexLock lock(&mrmw_mutex_, MRMWMutex::LockMode::kWriteLock);
228+
absl::WriterMutexLock resize_lock(&resize_mutex_);
229+
230+
if (nodes.empty()) {
231+
return;
232+
}
233+
234+
// Ensure we have enough capacity
235+
size_t required_capacity = metadata.cur_element_count;
236+
if (world_.max_elements_ < required_capacity) {
237+
world_.resizeIndex(required_capacity);
238+
}
239+
240+
// Restore each node - directly set up memory and fields
241+
for (const auto& node : nodes) {
242+
size_t internal_id = node.internal_id;
243+
244+
// Register label in lookup table
245+
world_.label_lookup_[node.global_id] = internal_id;
246+
247+
// Set the level
248+
world_.element_levels_[internal_id] = node.level;
249+
250+
// Clear level 0 memory and set label
251+
memset(world_.data_level0_memory_ + internal_id * world_.size_data_per_element_ +
252+
world_.offsetLevel0_,
253+
0, world_.size_data_per_element_);
254+
world_.setExternalLabel(internal_id, node.global_id);
255+
256+
// Allocate upper layer links if needed
257+
if (node.level > 0) {
258+
world_.linkLists_[internal_id] =
259+
(char*)mi_malloc(world_.size_links_per_element_ * node.level + 1);
260+
memset(world_.linkLists_[internal_id], 0, world_.size_links_per_element_ * node.level + 1);
261+
}
262+
263+
// Restore links for layer 0
264+
if (!node.levels_links.empty()) {
265+
auto* ll0 = world_.get_linklist0(internal_id);
266+
world_.setListCount(ll0, node.levels_links[0].size());
267+
auto* links0 = reinterpret_cast<uint32_t*>(ll0 + 1);
268+
std::copy(node.levels_links[0].begin(), node.levels_links[0].end(), links0);
269+
}
270+
271+
// Restore links for upper layers
272+
for (int lvl = 1; lvl <= node.level && lvl < static_cast<int>(node.levels_links.size());
273+
++lvl) {
274+
auto* ll = world_.get_linklist(internal_id, lvl);
275+
world_.setListCount(ll, node.levels_links[lvl].size());
276+
auto* links = reinterpret_cast<uint32_t*>(ll + 1);
277+
std::copy(node.levels_links[lvl].begin(), node.levels_links[lvl].end(), links);
278+
}
279+
}
280+
281+
// Set the metadata for the graph
282+
world_.cur_element_count.store(metadata.cur_element_count);
283+
world_.maxlevel_ = metadata.maxlevel;
284+
world_.enterpoint_node_ = metadata.enterpoint_node;
285+
286+
VLOG(1) << "Restored HNSW index with " << metadata.cur_element_count
287+
<< " nodes, maxlevel=" << metadata.maxlevel
288+
<< ", enterpoint=" << metadata.enterpoint_node;
289+
}
290+
291+
// Update vector data for an existing node (used after RestoreFromNodes)
292+
// Attaches vector payload to the node registered under label `id`.
//  * Owned mode (world_.copy_vector_): data_size_ bytes are copied from `data`
//    into the index's own vector storage.
//  * Borrowed mode: only the pointer value is stored in the node's data slot;
//    presumably the caller must keep that buffer alive - TODO confirm.
// A null `data` or a label that is not present in the lookup table is logged
// and ignored, leaving the index unchanged.
void UpdateVectorData(GlobalDocId id, const void* data) {
  // Reject null before locking: it would be undefined behaviour for memcpy in
  // owned mode and would install a null vector pointer in borrowed mode.
  if (data == nullptr) {
    LOG(WARNING) << "UpdateVectorData: null vector data for label " << id;
    return;
  }

  MRMWMutexLock lock(&mrmw_mutex_, MRMWMutex::LockMode::kWriteLock);

  // Find the internal id for this label.
  const auto it = world_.label_lookup_.find(id);
  if (it == world_.label_lookup_.end()) {
    LOG(WARNING) << "UpdateVectorData: label " << id << " not found in index";
    return;
  }

  const size_t internal_id = it->second;

  if (world_.copy_vector_) {
    // Owned mode: copy data into world's vector memory.
    char* data_ptr = world_.data_vector_memory_ + internal_id * world_.data_size_;
    memcpy(data_ptr, data, world_.data_size_);
  } else {
    // Borrowed mode: the slot holds a pointer, so write the pointer value itself.
    char* ptr_location = world_.getDataPtrByInternalId(internal_id);
    memcpy(ptr_location, &data, sizeof(void*));
  }
}
315+
316+
private:
211317
HnswSpace space_;
212318
HierarchicalNSW<float> world_;
213319
absl::Mutex resize_mutex_;
@@ -264,6 +370,10 @@ HnswIndexMetadata HnswVectorIndex::GetMetadata() const {
264370
return adapter_->GetMetadata();
265371
}
266372

373+
void HnswVectorIndex::SetMetadata(const HnswIndexMetadata& metadata) {
374+
adapter_->SetMetadata(metadata);
375+
}
376+
267377
size_t HnswVectorIndex::GetNodeCount() const {
268378
return adapter_->GetNodeCount();
269379
}
@@ -272,4 +382,28 @@ std::vector<HnswNodeData> HnswVectorIndex::GetNodesRange(size_t start, size_t en
272382
return adapter_->GetNodesRange(start, end);
273383
}
274384

385+
void HnswVectorIndex::RestoreFromNodes(const std::vector<HnswNodeData>& nodes,
386+
const HnswIndexMetadata& metadata) {
387+
adapter_->RestoreFromNodes(nodes, metadata);
388+
restored_ = true;
389+
}
390+
391+
// Fetches the vector stored under `field` in `doc` (expected dimension dim_) and
// passes its raw buffer to the adapter for node `id`.
// Returns false when the document yields no vector for the field. Otherwise
// returns true; the adapter call returns nothing, so a label the adapter does
// not know is not reflected in the result.
bool HnswVectorIndex::UpdateVectorData(GlobalDocId id, const DocumentAccessor& doc,
                                       std::string_view field) {
  auto vector_ptr = doc.GetVector(field, dim_);
  if (!vector_ptr)
    return false;

  // An owned vector exposes its buffer via `.first.get()`; a borrowed vector is
  // used directly as the pointer.
  const void* raw = nullptr;
  if (auto* owned = std::get_if<OwnedFtVector>(&*vector_ptr); owned != nullptr) {
    raw = owned->first.get();
  } else {
    raw = std::get<BorrowedFtVector>(*vector_ptr);
  }

  adapter_->UpdateVectorData(id, raw);
  return true;
}
408+
275409
} // namespace dfly::search

src/core/search/hnsw_index.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,18 +64,40 @@ class HnswVectorIndex {
6464
// Get metadata for serialization
6565
HnswIndexMetadata GetMetadata() const;
6666

67+
// Set metadata (used during restoration)
68+
void SetMetadata(const HnswIndexMetadata& metadata);
69+
6770
// Get total number of nodes in the index
6871
size_t GetNodeCount() const;
6972

7073
// Get nodes in the specified range [start, end)
7174
// Returns vector of node data for serialization
7275
std::vector<HnswNodeData> GetNodesRange(size_t start, size_t end) const;
7376

77+
// Restore graph structure from serialized nodes with metadata
78+
// This restores the HNSW graph links but NOT the vector data
79+
// Vector data must be populated separately via UpdateVectorData
80+
void RestoreFromNodes(const std::vector<HnswNodeData>& nodes, const HnswIndexMetadata& metadata);
81+
82+
// Update vector data for an existing node (used after RestoreFromNodes)
83+
// This populates the vector data for a node that already has graph links
84+
bool UpdateVectorData(GlobalDocId id, const DocumentAccessor& doc, std::string_view field);
85+
86+
// Mark index as restored from RDB (should use UpdateVectorData instead of Add)
87+
void SetRestored(bool restored) {
88+
restored_ = restored;
89+
}
90+
91+
bool IsRestored() const {
92+
return restored_;
93+
}
94+
7495
private:
7596
bool copy_vector_;
7697
size_t dim_;
7798
VectorSimilarity sim_;
7899
std::unique_ptr<HnswlibAdapter> adapter_;
100+
bool restored_ = false;
79101
};
80102

81103
} // namespace dfly::search

src/server/rdb_load.cc

Lines changed: 107 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include "server/rdb_load.h"
66

77
#include "absl/strings/escaping.h"
8+
#include "server/search/global_hnsw_index.h"
89
#include "server/tiered_storage.h"
910

1011
extern "C" {
@@ -2213,8 +2214,8 @@ error_code RdbLoader::Load(io::Source* src) {
22132214
}
22142215

22152216
if (type == RDB_OPCODE_VECTOR_INDEX) {
2216-
// Stub: read and ignore HNSW vector index data
2217-
// Binary format: [index_name, elements_number,
2217+
// Read HNSW vector index graph data and restore directly
2218+
// Binary format: [index_key, elements_number,
22182219
// then for each node (little-endian):
22192220
// internal_id (4 bytes), global_id (8 bytes), level (4 bytes),
22202221
// for each level (0 to level): links_num (4 bytes) + links (4 bytes each)]
@@ -2224,26 +2225,93 @@ error_code RdbLoader::Load(io::Source* src) {
22242225
uint64_t elements_number;
22252226
SET_OR_RETURN(LoadLen(nullptr), elements_number);
22262227

2228+
// Only restore if shard count matches (GlobalDocId encodes shard_id)
2229+
bool should_restore =
2230+
shard_count_ > 0 && shard_set != nullptr && shard_count_ == shard_set->size();
2231+
2232+
// Extract index_name and field_name from index_key
2233+
size_t colon_pos = index_key.find(':');
2234+
string index_name = (colon_pos != string::npos) ? index_key.substr(0, colon_pos) : index_key;
2235+
string field_name = (colon_pos != string::npos) ? index_key.substr(colon_pos + 1) : "";
2236+
2237+
// Check if we can get the HNSW index (it should exist from FT.CREATE in aux)
2238+
auto hnsw_index = should_restore
2239+
? GlobalHnswIndexRegistry::Instance().Get(index_name, field_name)
2240+
: nullptr;
2241+
if (should_restore && !hnsw_index) {
2242+
LOG(WARNING) << "HNSW index not found for restoration: " << index_key;
2243+
should_restore = false;
2244+
}
2245+
2246+
std::vector<search::HnswNodeData> nodes;
2247+
if (should_restore) {
2248+
nodes.reserve(elements_number);
2249+
}
2250+
22272251
for (uint64_t elem = 0; elem < elements_number; ++elem) {
2228-
[[maybe_unused]] uint32_t internal_id;
2252+
uint32_t internal_id;
22292253
SET_OR_RETURN(FetchInt<uint32_t>(), internal_id);
2230-
[[maybe_unused]] uint64_t global_id;
2254+
uint64_t global_id;
22312255
SET_OR_RETURN(FetchInt<uint64_t>(), global_id);
22322256
uint32_t level;
22332257
SET_OR_RETURN(FetchInt<uint32_t>(), level);
22342258

2259+
search::HnswNodeData node;
2260+
if (should_restore) {
2261+
node.internal_id = internal_id;
2262+
node.global_id = global_id;
2263+
node.level = level;
2264+
node.levels_links.resize(level + 1);
2265+
}
2266+
22352267
for (uint32_t lvl = 0; lvl <= level; ++lvl) {
22362268
uint32_t links_num;
22372269
SET_OR_RETURN(FetchInt<uint32_t>(), links_num);
2270+
2271+
if (should_restore) {
2272+
node.levels_links[lvl].reserve(links_num);
2273+
}
2274+
22382275
for (uint32_t i = 0; i < links_num; ++i) {
2239-
[[maybe_unused]] uint32_t link;
2276+
uint32_t link;
22402277
SET_OR_RETURN(FetchInt<uint32_t>(), link);
2278+
if (should_restore) {
2279+
node.levels_links[lvl].push_back(link);
2280+
}
22412281
}
22422282
}
2283+
2284+
if (should_restore) {
2285+
nodes.push_back(std::move(node));
2286+
}
22432287
}
22442288

2245-
VLOG(2) << "Ignoring HNSW vector index: " << index_key
2246-
<< " elements_number=" << elements_number;
2289+
if (should_restore && !nodes.empty()) {
2290+
// Get metadata - LoadSearchIndexDefFromAux stores it via SetMetadata after FT.CREATE
2291+
search::HnswIndexMetadata metadata = hnsw_index->GetMetadata();
2292+
2293+
if (metadata.cur_element_count == 0) {
2294+
// Create default metadata from graph data
2295+
metadata.cur_element_count = nodes.size();
2296+
metadata.maxlevel = -1;
2297+
metadata.enterpoint_node = 0;
2298+
for (const auto& node : nodes) {
2299+
if (node.level > metadata.maxlevel) {
2300+
metadata.maxlevel = node.level;
2301+
metadata.enterpoint_node = node.internal_id;
2302+
}
2303+
}
2304+
}
2305+
2306+
// Restore the HNSW graph directly and mark as restored
2307+
hnsw_index->RestoreFromNodes(nodes, metadata);
2308+
2309+
LOG(INFO) << "Restored HNSW index " << index_key << " with " << nodes.size() << " nodes";
2310+
} else if (elements_number > 0) {
2311+
VLOG(2) << "Skipping HNSW vector index restore: " << index_key
2312+
<< " elements_number=" << elements_number << " shard_count_=" << shard_count_
2313+
<< " current_shards=" << (shard_set ? shard_set->size() : 0);
2314+
}
22472315
continue;
22482316
}
22492317

@@ -2975,7 +3043,6 @@ std::vector<std::string> RdbLoader::pending_synonym_cmds_;
29753043
// Static synchronization for thread-safe search index creation
29763044
base::SpinLock RdbLoader::search_index_mu_;
29773045
absl::flat_hash_set<std::string> RdbLoader::created_search_indices_;
2978-
29793046
std::vector<std::string> RdbLoader::TakePendingSynonymCommands() {
29803047
std::vector<std::string> result;
29813048
result.swap(pending_synonym_cmds_);
@@ -2985,6 +3052,8 @@ std::vector<std::string> RdbLoader::TakePendingSynonymCommands() {
29853052
void RdbLoader::LoadSearchIndexDefFromAux(string&& def) {
29863053
string index_name;
29873054
string full_cmd;
3055+
string hnsw_field_name;
3056+
std::optional<search::HnswIndexMetadata> hnsw_meta;
29883057

29893058
// Check if this is new JSON format (starts with '{') or old format ("index_name cmd")
29903059
if (!def.empty() && def[0] == '{') {
@@ -2998,9 +3067,28 @@ void RdbLoader::LoadSearchIndexDefFromAux(string&& def) {
29983067
const auto& json = *json_opt;
29993068
index_name = json["name"].as<string>();
30003069
string cmd = json["cmd"].as<string>();
3001-
3002-
// TODO: restore HNSW metadata from json["hnsw_metadata"] if present
3003-
// Currently we just restore the index definition, HNSW graph will be rebuilt
3070+
hnsw_field_name = json["field"].as<string>();
3071+
3072+
// Parse HNSW metadata if present
3073+
if (json.contains("hnsw_metadata")) {
3074+
const auto& meta = json["hnsw_metadata"];
3075+
search::HnswIndexMetadata m;
3076+
m.max_elements = meta["max_elements"].as<size_t>();
3077+
m.cur_element_count = meta["cur_element_count"].as<size_t>();
3078+
m.maxlevel = meta["maxlevel"].as<int>();
3079+
m.enterpoint_node = meta["enterpoint_node"].as<size_t>();
3080+
m.M = meta["M"].as<size_t>();
3081+
m.maxM = meta["maxM"].as<size_t>();
3082+
m.maxM0 = meta["maxM0"].as<size_t>();
3083+
m.ef_construction = meta["ef_construction"].as<size_t>();
3084+
m.mult = meta["mult"].as<double>();
3085+
hnsw_meta = m;
3086+
3087+
VLOG(1) << "Parsed HNSW metadata for index " << index_name << " field " << hnsw_field_name
3088+
<< ": max_elements=" << m.max_elements
3089+
<< " cur_element_count=" << m.cur_element_count << " maxlevel=" << m.maxlevel
3090+
<< " M=" << m.M;
3091+
}
30043092

30053093
full_cmd = absl::StrCat(index_name, " ", cmd);
30063094
} catch (const std::exception& e) {
@@ -3032,6 +3120,13 @@ void RdbLoader::LoadSearchIndexDefFromAux(string&& def) {
30323120
}
30333121

30343122
LoadSearchCommandFromAux(service_, std::move(full_cmd), "FT.CREATE", "index definition");
3123+
3124+
// Store metadata on HNSW index after index creation (for later graph restoration)
3125+
if (hnsw_meta && !hnsw_field_name.empty()) {
3126+
if (auto index = GlobalHnswIndexRegistry::Instance().Get(index_name, hnsw_field_name); index) {
3127+
index->SetMetadata(*hnsw_meta);
3128+
}
3129+
}
30353130
}
30363131

30373132
void RdbLoader::LoadSearchSynonymsFromAux(string&& def) {
@@ -3055,7 +3150,7 @@ void RdbLoader::PerformPostLoad(Service* service, bool is_error) {
30553150
if (is_error)
30563151
return;
30573152

3058-
// Rebuild all search indices as only their definitions are extracted from the snapshot
3153+
// Rebuild all search indices - for restored HNSW indices, this will populate vectors
30593154
shard_set->AwaitRunningOnShardQueue([](EngineShard* es) {
30603155
OpArgs op_args{es, nullptr,
30613156
DbContext{&namespaces->GetDefaultNamespace(), 0, GetCurrentTimeMs()}};

0 commit comments

Comments
 (0)