Skip to content

Commit d4c881d

Browse files
committed
Add recall test for hnsw via python bindings
1 parent a6af73d commit d4c881d

File tree

1 file changed

+60
-0
lines changed

1 file changed

+60
-0
lines changed

examples/recall_test.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import hnswlib
2+
import numpy as np
3+
4+
dim = 128
5+
num_elements = 100000
6+
k = 10
7+
nun_queries = 10
8+
9+
# Generating sample data
10+
data = np.float32(np.random.random((num_elements, dim)))
11+
12+
# Declaring index
13+
hnsw_index = hnswlib.Index(space='l2', dim=dim) # possible options are l2, cosine or ip
14+
bf_index = hnswlib.BFIndex(space='l2', dim=dim)
15+
16+
# Initing both hnsw and brute force indices
17+
# max_elements - the maximum number of elements (capacity). Will throw an exception if exceeded
18+
# during insertion of an element.
19+
# The capacity can be increased by saving/loading the index, see below.
20+
#
21+
# hnsw construction params:
22+
# ef_construction - controls index search speed/build speed tradeoff
23+
#
24+
# M - is tightly connected with internal dimensionality of the data. Strongly affects the memory consumption (~M)
25+
# Higher M leads to higher accuracy/run_time at fixed ef/efConstruction
26+
27+
hnsw_index.init_index(max_elements=num_elements, ef_construction=10, M=6)
28+
bf_index.init_index(max_elements=num_elements)
29+
30+
# Controlling the recall for hnsw by setting ef:
31+
# higher ef leads to better accuracy, but slower search
32+
hnsw_index.set_ef(10)
33+
34+
# Set number of threads used during batch search/construction in hnsw
35+
# By default using all available cores
36+
hnsw_index.set_num_threads(1)
37+
38+
print("Adding batch of %d elements" % (len(data)))
39+
hnsw_index.add_items(data)
40+
bf_index.add_items(data)
41+
42+
print("Indices built")
43+
44+
# Generating query data
45+
query_data = np.float32(np.random.random((10, dim)))
46+
47+
# Query the elements and measure recall:
48+
labels_hnsw, distances_hnsw = hnsw_index.knn_query(query_data, k)
49+
labels_bf, distances_bf = bf_index.knn_query(query_data, k)
50+
51+
# Measure recall
52+
correct = 0
53+
for i in range(nun_queries):
54+
for label in labels_hnsw[i]:
55+
for correct_label in labels_bf[i]:
56+
if label == correct_label:
57+
correct += 1
58+
break
59+
60+
print("recall is :", float(correct)/(k*nun_queries))

0 commit comments

Comments
 (0)