Skip to content

Commit 079c71e

Browse files
committed
Add load and store index to the bindings, update test recall
1 parent d4c881d commit 079c71e

File tree

5 files changed

+64
-79
lines changed

5 files changed

+64
-79
lines changed

examples/searchKnnCloserFirst_test.cpp

Lines changed: 3 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -9,69 +9,12 @@
99

1010
#include <vector>
1111
#include <iostream>
12-
#include <thread>
1312

1413
namespace
1514
{
1615

1716
using idx_t = hnswlib::labeltype;
1817

19-
template<class Function>
20-
inline void ParallelFor(size_t start, size_t end, size_t numThreads, Function fn) {
21-
if (numThreads <= 0) {
22-
numThreads = std::thread::hardware_concurrency();
23-
}
24-
25-
if (numThreads == 1) {
26-
for (size_t id = start; id < end; id++) {
27-
fn(id, 0);
28-
}
29-
} else {
30-
std::vector<std::thread> threads;
31-
std::atomic<size_t> current(start);
32-
33-
// keep track of exceptions in threads
34-
// https://stackoverflow.com/a/32428427/1713196
35-
std::exception_ptr lastException = nullptr;
36-
std::mutex lastExceptMutex;
37-
38-
for (size_t threadId = 0; threadId < numThreads; ++threadId) {
39-
threads.push_back(std::thread([&, threadId] {
40-
while (true) {
41-
size_t id = current.fetch_add(1);
42-
43-
if ((id >= end)) {
44-
break;
45-
}
46-
47-
try {
48-
fn(id, threadId);
49-
} catch (...) {
50-
std::unique_lock<std::mutex> lastExcepLock(lastExceptMutex);
51-
lastException = std::current_exception();
52-
/*
53-
* This will work even when current is the largest value that
54-
* size_t can fit, because fetch_add returns the previous value
55-
* before the increment (which will result in overflow
56-
* and produce 0 instead of current + 1).
57-
*/
58-
current = end;
59-
break;
60-
}
61-
}
62-
}));
63-
}
64-
for (auto &thread : threads) {
65-
thread.join();
66-
}
67-
if (lastException) {
68-
std::rethrow_exception(lastException);
69-
}
70-
}
71-
72-
73-
}
74-
7518
void test() {
7619
int d = 4;
7720
idx_t n = 100;
@@ -97,18 +40,10 @@ void test() {
9740
hnswlib::AlgorithmInterface<float>* alg_brute = new hnswlib::BruteforceSearch<float>(&space, 2 * n);
9841
hnswlib::AlgorithmInterface<float>* alg_hnsw = new hnswlib::HierarchicalNSW<float>(&space, 2 * n);
9942

100-
// for (size_t i = 0; i < n; ++i) {
101-
// alg_brute->addPoint(data.data() + d * i, i);
102-
// alg_hnsw->addPoint(data.data() + d * i, i);
103-
// }
104-
105-
ParallelFor(0, n, 4, [&](size_t i, size_t threadId) {
106-
alg_hnsw->addPoint(data.data() + d * i, i);
107-
});
108-
109-
ParallelFor(0, n, 4, [&](size_t i, size_t threadId) {
43+
for (size_t i = 0; i < n; ++i) {
11044
alg_brute->addPoint(data.data() + d * i, i);
111-
});
45+
alg_hnsw->addPoint(data.data() + d * i, i);
46+
}
11247

11348
// test searchKnnCloserFirst of BruteforceSearch
11449
for (size_t j = 0; j < nq; ++j) {

hnswlib/bruteforce.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@ namespace hnswlib {
6868
memcpy(data_ + size_per_element_ * idx, datapoint, data_size_);
6969

7070

71+
72+
7173
};
7274

7375
void removePoint(labeltype cur_external) {
@@ -97,7 +99,6 @@ namespace hnswlib {
9799
dist_t lastdist = topResults.top().first;
98100
for (int i = k; i < cur_element_count; i++) {
99101
dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_);
100-
101102
if (dist <= lastdist) {
102103
topResults.push(std::pair<dist_t, labeltype>(dist, *((labeltype *) (data_ + size_per_element_ * i +
103104
data_size_))));

python_bindings/__init__.py

Whitespace-only changes.

python_bindings/bindings.cpp

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -743,21 +743,39 @@ class BFIndex {
743743
throw std::runtime_error("wrong dimensionality of the labels");
744744
}
745745
{
746-
int start = 0;
747-
py::gil_scoped_release l;
748746

749-
std::vector<float> norm_array(dim);
750-
for (size_t i = start; i < rows; i++) {
751-
alg->addPoint((void *) items.data(i), (size_t) i);
747+
for (size_t row = 0; row < rows; row++) {
748+
size_t id = ids.size() ? ids.at(row) : cur_l + row;
749+
if (!normalize) {
750+
alg->addPoint((void *) items.data(row), (size_t) id);
751+
} else {
752+
float normalized_vector[dim];
753+
normalize_vector((float *)items.data(row), normalized_vector);
754+
alg->addPoint((void *) normalized_vector, (size_t) id);
755+
}
752756
}
753757
cur_l+=rows;
754758
}
755759
}
756760

757-
void deletedVector(size_t label) {
761+
void deleteVector(size_t label) {
758762
alg->removePoint(label);
759763
}
760764

765+
void saveIndex(const std::string &path_to_index) {
766+
alg->saveIndex(path_to_index);
767+
}
768+
769+
void loadIndex(const std::string &path_to_index, size_t max_elements) {
770+
if (alg) {
771+
std::cerr<<"Warning: Calling load_index for an already inited index. Old index is being deallocated.";
772+
delete alg;
773+
}
774+
alg = new hnswlib::BruteforceSearch<dist_t>(space, path_to_index);
775+
cur_l = alg->cur_element_count;
776+
index_inited = true;
777+
}
778+
761779
py::object knnQuery_return_numpy(py::object input, size_t k = 1) {
762780

763781
py::array_t < dist_t, py::array::c_style | py::array::forcecast > items(input);
@@ -885,6 +903,9 @@ PYBIND11_PLUGIN(hnswlib) {
885903
.def("init_index", &BFIndex<float>::init_new_index, py::arg("max_elements"))
886904
.def("knn_query", &BFIndex<float>::knnQuery_return_numpy, py::arg("data"), py::arg("k")=1)
887905
.def("add_items", &BFIndex<float>::addItems, py::arg("data"), py::arg("ids") = py::none())
906+
.def("delete_vector", &BFIndex<float>::deleteVector, py::arg("label"))
907+
.def("save_index", &BFIndex<float>::saveIndex, py::arg("path_to_index"))
908+
.def("load_index", &BFIndex<float>::loadIndex, py::arg("path_to_index"), py::arg("max_elements")=0)
888909
.def("__repr__", [](const BFIndex<float> &a) {
889910
return "<hnswlib.BFIndex(space='" + a.space_name + "', dim="+std::to_string(a.dim)+")>";
890911
});

examples/recall_test.py renamed to python_bindings/tests/bindings_test_recall.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import hnswlib
22
import numpy as np
33

4-
dim = 128
4+
dim = 32
55
num_elements = 100000
66
k = 10
77
nun_queries = 10
@@ -24,12 +24,12 @@
2424
# M is tightly connected with the internal dimensionality of the data. It strongly affects memory consumption (~M)
2525
# Higher M leads to higher accuracy/run_time at fixed ef/efConstruction
2626

27-
hnsw_index.init_index(max_elements=num_elements, ef_construction=10, M=6)
27+
hnsw_index.init_index(max_elements=num_elements, ef_construction=200, M=16)
2828
bf_index.init_index(max_elements=num_elements)
2929

3030
# Controlling the recall for hnsw by setting ef:
3131
# higher ef leads to better accuracy, but slower search
32-
hnsw_index.set_ef(10)
32+
hnsw_index.set_ef(200)
3333

3434
# Set number of threads used during batch search/construction in hnsw
3535
# By default using all available cores
@@ -42,7 +42,7 @@
4242
print("Indices built")
4343

4444
# Generating query data
45-
query_data = np.float32(np.random.random((10, dim)))
45+
query_data = np.float32(np.random.random((nun_queries, dim)))
4646

4747
# Query the elements and measure recall:
4848
labels_hnsw, distances_hnsw = hnsw_index.knn_query(query_data, k)
@@ -58,3 +58,31 @@
5858
break
5959

6060
print("recall is :", float(correct)/(k*nun_queries))
61+
62+
# test serializing the brute force index
63+
index_path = 'bf_index.bin'
64+
print("Saving index to '%s'" % index_path)
65+
bf_index.save_index(index_path)
66+
del bf_index
67+
68+
# Re-initializing, loading the index
69+
bf_index = hnswlib.BFIndex(space='l2', dim=dim)
70+
71+
print("\nLoading index from '%s'\n" % index_path)
72+
bf_index.load_index(index_path)
73+
74+
# Query the brute force index again to verify that we get the same results
75+
labels_bf, distances_bf = bf_index.knn_query(query_data, k)
76+
77+
# Measure recall
78+
correct = 0
79+
for i in range(nun_queries):
80+
for label in labels_hnsw[i]:
81+
for correct_label in labels_bf[i]:
82+
if label == correct_label:
83+
correct += 1
84+
break
85+
86+
print("recall after reloading is :", float(correct)/(k*nun_queries))
87+
88+

0 commit comments

Comments
 (0)