Skip to content

Commit 4eb5822

Browse files
committed
test: unittest for hnsw
1 parent 81f0683 commit 4eb5822

File tree

1 file changed

+38
-0
lines changed

1 file changed

+38
-0
lines changed

tests/test_hnsw.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import unittest
2+
import sys
3+
sys.path.append('..')
4+
sys.path.append("build")
5+
import numpy as np
6+
from zenann import HNSWIndex
7+
8+
class TestHNSWRecall(unittest.TestCase):
9+
def setUp(self):
10+
# one-hot vectors
11+
self.dim = 8
12+
self.data = [ [1.0 if i==j else 0.0 for i in range(self.dim)] for j in range(self.dim) ]
13+
self.M = self.dim
14+
self.efC = 16
15+
self.efS = self.dim
16+
self.index = HNSWIndex(dim=self.dim, M=self.M, efConstruction=self.efC)
17+
self.index.set_ef_search(self.efS)
18+
self.index.build(self.data)
19+
20+
def test_perfect_recall_k1(self):
21+
for true_id, q in enumerate(self.data):
22+
res = self.index.search(q, k=1)
23+
self.assertEqual(len(res.indices), 1)
24+
self.assertEqual(res.indices[0], true_id)
25+
self.assertAlmostEqual(res.distances[0], 0.0, places=6)
26+
27+
def test_recall_rate_k2(self):
28+
hits = 0
29+
for true_id, q in enumerate(self.data):
30+
res = self.index.search(q, k=2)
31+
self.assertGreaterEqual(len(res.indices), 1)
32+
if true_id in res.indices:
33+
hits += 1
34+
recall = hits / len(self.data)
35+
self.assertEqual(recall, 1.0)
36+
37+
if __name__ == '__main__':
38+
unittest.main()

0 commit comments

Comments
 (0)