Skip to content

Commit 81f0683

Browse files
committed
feat: IndexHNSW implementation with faiss
1 parent 5409459 commit 81f0683

File tree

2 files changed

+162
-0
lines changed

2 files changed

+162
-0
lines changed

include/zenann/HNSWIndex.h

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
#pragma once

#include "IndexBase.h"

#include <faiss/IndexHNSW.h>

#include <memory>
#include <string>
#include <vector>
6+
7+
namespace zenann {
8+
9+
class HNSWIndex : public IndexBase {
10+
public:
11+
HNSWIndex(size_t dim, size_t M, size_t efConstruction = 200);
12+
~HNSWIndex() override;
13+
void train() override;
14+
SearchResult search(const Vector& query, size_t k) const override;
15+
SearchResult search(const Vector& query, size_t k, size_t efSearch) const;
16+
std::vector<SearchResult> search_batch(const Dataset& queries, size_t k) const;
17+
std::vector<SearchResult> search_batch(const Dataset& queries, size_t k, size_t efSearch) const;
18+
19+
void set_ef_search(size_t efSearch);
20+
void reorder_layout();
21+
void reorder_layout(const std::string& mapping_file);
22+
23+
void write_index(const std::string& filename) const;
24+
static std::shared_ptr<HNSWIndex> read_index(const std::string& filename);
25+
26+
static double compute_recall_with_mapping(
27+
const std::vector<std::vector<faiss::idx_t>>& groundtruth,
28+
const std::vector<faiss::idx_t>& predicted_flat,
29+
size_t nq, size_t k,
30+
const std::string& mapping_file);
31+
32+
private:
33+
std::unique_ptr<faiss::IndexHNSWFlat> idx_;
34+
};
35+
36+
}

src/HNSWIndex.cpp

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
#include "HNSWIndex.h"

#include <faiss/index_io.h>

#include <algorithm>
#include <fstream>
#include <memory>
#include <numeric>
#include <set>
#include <stdexcept>
#include <vector>
6+
7+
namespace zenann {
8+
9+
HNSWIndex::HNSWIndex(size_t dim, size_t M, size_t efConstruction)
10+
: IndexBase(dim), idx_(new faiss::IndexHNSWFlat(dim, M)) {
11+
idx_->hnsw.efConstruction = efConstruction;
12+
}
13+
14+
HNSWIndex::~HNSWIndex() = default;
15+
16+
void HNSWIndex::train() {
17+
const auto& data = datastore_->getAll();
18+
if (data.empty()) return;
19+
size_t n = data.size();
20+
std::vector<float> flat(n * dimension_);
21+
for (size_t i = 0; i < n; ++i) {
22+
std::copy(data[i].begin(), data[i].end(), flat.begin() + i * dimension_);
23+
}
24+
idx_->add(n, flat.data());
25+
}
26+
27+
SearchResult HNSWIndex::search(const Vector& query, size_t k) const {
28+
std::vector<faiss::idx_t> labels(k);
29+
std::vector<float> distances(k);
30+
idx_->search(1, query.data(), k, distances.data(), labels.data());
31+
32+
SearchResult result;
33+
result.indices.assign(labels.begin(), labels.end());
34+
result.distances.assign(distances.begin(), distances.end());
35+
return result;
36+
}
37+
38+
SearchResult HNSWIndex::search(const Vector& query, size_t k, size_t efSearch) const {
39+
idx_->hnsw.efSearch = efSearch;
40+
return search(query, k);
41+
}
42+
43+
std::vector<SearchResult> HNSWIndex::search_batch(const Dataset& queries, size_t k) const {
44+
return search_batch(queries, k, idx_->hnsw.efSearch);
45+
}
46+
47+
std::vector<SearchResult> HNSWIndex::search_batch(const Dataset& queries, size_t k, size_t efSearch) const {
48+
idx_->hnsw.efSearch = efSearch;
49+
size_t nq = queries.size();
50+
std::vector<float> flat(nq * dimension_);
51+
for (size_t i = 0; i < nq; ++i) {
52+
std::copy(queries[i].begin(), queries[i].end(), flat.begin() + i * dimension_);
53+
}
54+
std::vector<faiss::idx_t> labels(nq * k);
55+
std::vector<float> distances(nq * k);
56+
idx_->search(nq, flat.data(), k, distances.data(), labels.data());
57+
58+
std::vector<SearchResult> results(nq);
59+
for (size_t i = 0; i < nq; ++i) {
60+
auto lb = labels.begin() + i * k;
61+
auto dd = distances.begin() + i * k;
62+
results[i].indices.assign(lb, lb + k);
63+
results[i].distances.assign(dd, dd + k);
64+
}
65+
return results;
66+
}
67+
68+
void HNSWIndex::set_ef_search(size_t efSearch) {
69+
idx_->hnsw.efSearch = efSearch;
70+
}
71+
72+
void HNSWIndex::reorder_layout() {
73+
idx_->bfs_reorder();
74+
}
75+
76+
void HNSWIndex::reorder_layout(const std::string& mapping_file) {
77+
auto new_order = idx_->bfs_reorder();
78+
std::ofstream ofs(mapping_file, std::ios::binary);
79+
size_t sz = new_order.size();
80+
ofs.write(reinterpret_cast<const char*>(&sz), sizeof(sz));
81+
ofs.write(reinterpret_cast<const char*>(new_order.data()), sz * sizeof(int));
82+
}
83+
84+
85+
void HNSWIndex::write_index(const std::string& filename) const {
86+
faiss::write_index(idx_.get(), filename.c_str());
87+
}
88+
89+
std::shared_ptr<HNSWIndex> HNSWIndex::read_index(const std::string& filename) {
90+
faiss::Index* base = faiss::read_index(filename.c_str());
91+
auto raw = dynamic_cast<faiss::IndexHNSWFlat*>(base);
92+
auto inst = std::make_shared<HNSWIndex>(raw->d, raw->hnsw.nb_neighbors(0));
93+
inst->idx_.reset(raw);
94+
return inst;
95+
}
96+
97+
double HNSWIndex::compute_recall_with_mapping(
98+
const std::vector<std::vector<faiss::idx_t>>& groundtruth,
99+
const std::vector<faiss::idx_t>& predicted_flat,
100+
size_t nq, size_t k,
101+
const std::string& mapping_file) {
102+
std::ifstream ifs(mapping_file, std::ios::binary);
103+
size_t sz;
104+
ifs.read(reinterpret_cast<char*>(&sz), sizeof(sz));
105+
std::vector<int> old_to_new(sz);
106+
ifs.read(reinterpret_cast<char*>(old_to_new.data()), sz * sizeof(int));
107+
std::vector<std::vector<faiss::idx_t>> pred(nq, std::vector<faiss::idx_t>(k));
108+
for (size_t i = 0; i < nq; ++i)
109+
for (size_t j = 0; j < k; ++j)
110+
pred[i][j] = predicted_flat[i*k + j];
111+
std::vector<double> recalls;
112+
for (size_t i = 0; i < nq; ++i) {
113+
std::set<faiss::idx_t> ts;
114+
for (size_t j = 0; j < k; ++j) {
115+
auto oldid = groundtruth[i][j];
116+
ts.insert(old_to_new[oldid]);
117+
}
118+
std::set<faiss::idx_t> ps(pred[i].begin(), pred[i].end());
119+
size_t hits=0;
120+
for (auto pid: ps) if (ts.count(pid)) ++hits;
121+
recalls.push_back(double(hits)/ts.size());
122+
}
123+
return std::accumulate(recalls.begin(), recalls.end(), 0.0)/recalls.size();
124+
}
125+
126+
}

0 commit comments

Comments
 (0)