forked from aponom84/navigable-graphs-python
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest-hnsw.py
More file actions
139 lines (107 loc) · 5.43 KB
/
test-hnsw.py
File metadata and controls
139 lines (107 loc) · 5.43 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
# #!/usr/bin/env python
# # coding: utf-8
# import numpy as np
# import argparse
# from tqdm import tqdm
# # from tqdm import tqdm_notebook as tqdm
# from heapq import heappush, heappop
# import random
# import itertools
# random.seed(108)
# from hnsw import HNSW
# from hnsw import l2_distance, heuristic
# def brute_force_knn_search(distance_func, k, q, data):
# '''
# Return the list of (idx, dist) for k-closest elements to {x} in {data}
# '''
# return sorted(enumerate(map(lambda x: distance_func(q, x) ,data)), key=lambda a: a[1])[:k]
# def calculate_recall(distance_func, kg, test, groundtruth, k, ef, m):
# if groundtruth is None:
# print("Ground truth not found. Calculating ground truth...")
# groundtruth = [ [idx for idx, dist in brute_force_knn_search(distance_func, k, query, kg.data)] for query in tqdm(test)]
# print("Calculating recall...")
# recalls = []
# total_calc = 0
# for query, true_neighbors in tqdm(zip(test, groundtruth), total=len(test)):
# true_neighbors = true_neighbors[:k] # Use only the top k ground truth neighbors
# entry_points = random.sample(range(len(kg.data)), m)
# observed = [neighbor for neighbor, dist in kg.search(q=query, k=k, ef=ef, return_observed = True)]
# total_calc = total_calc + len(observed)
# results = observed[:k]
# intersection = len(set(true_neighbors).intersection(set(results)))
# # print(f'true_neighbors: {true_neighbors}, results: {results}. Intersection: {intersection}')
# recall = intersection / k
# recalls.append(recall)
# return np.mean(recalls), total_calc/len(test)
# def read_fvecs(filename):
# with open(filename, 'rb') as f:
# while True:
# vec_size = np.fromfile(f, dtype=np.int32, count=1)
# if not vec_size:
# break
# vec = np.fromfile(f, dtype=np.float32, count=vec_size[0])
# yield vec
# def read_ivecs(filename):
# with open(filename, 'rb') as f:
# while True:
# vec_size = np.fromfile(f, dtype=np.int32, count=1)
# if not vec_size:
# break
# vec = np.fromfile(f, dtype=np.int32, count=vec_size[0])
# yield vec
# def load_sift_dataset():
# train_file = 'datasets/siftsmall/siftsmall_base.fvecs'
# test_file = 'datasets/siftsmall/siftsmall_query.fvecs'
# groundtruth_file = 'datasets/siftsmall/siftsmall_groundtruth.ivecs'
# train_data = np.array(list(read_fvecs(train_file)))
# test_data = np.array(list(read_fvecs(test_file)))
# groundtruth_data = np.array(list(read_ivecs(groundtruth_file)))
# return train_data, test_data, groundtruth_data
# def generate_synthetic_data(dim, n, nq):
# train_data = np.random.random((n, dim)).astype(np.float32)
# test_data = np.random.random((nq, dim)).astype(np.float32)
# return train_data, test_data
# def main():
# parser = argparse.ArgumentParser(description='Test recall of beam search method with KGraph.')
# parser.add_argument('--dataset', choices=['synthetic', 'sift'], default='synthetic', help="Choose the dataset to use: 'synthetic' or 'sift'.")
# parser.add_argument('--K', type=int, default=5, help='The size of the neighbourhood')
# parser.add_argument('--M', type=int, default=32, help='Avg number of neighbors')
# parser.add_argument('--M0', type=int, default=64, help='Avg number of neighbors')
# parser.add_argument('--dim', type=int, default=2, help='Dimensionality of synthetic data (ignored for SIFT).')
# parser.add_argument('--n', type=int, default=200, help='Number of training points for synthetic data (ignored for SIFT).')
# parser.add_argument('--nq', type=int, default=50, help='Number of query points for synthetic data (ignored for SIFT).')
# parser.add_argument('--k', type=int, default=5, help='Number of nearest neighbors to search in the test stage')
# parser.add_argument('--ef', type=int, default=10, help='Size of the beam for beam search.')
# parser.add_argument('--m', type=int, default=3, help='Number of random entry points.')
# args = parser.parse_args()
# # Load dataset
# if args.dataset == 'sift':
# print("Loading SIFT dataset...")
# train_data, test_data, groundtruth_data = load_sift_dataset()
# else:
# print(f"Generating synthetic dataset with {args.dim}-dimensional space...")
# train_data, test_data = generate_synthetic_data(args.dim, args.n, args.nq)
# groundtruth_data = None
# # Create HNSW
# for _ef in range(0, 100, 10):
# hnsw = HNSW( distance_func=l2_distance, m=args.M, m0=args.M0, ef=_ef, ef_construction=30, neighborhood_construction = heuristic)
# # Add data to HNSW
# for x in tqdm(train_data):
# hnsw.add(x)
# print(f"ef = {_ef}")
# # Calculate recall
# recall, avg_cal = calculate_recall(l2_distance, hnsw, test_data, groundtruth_data, k=args.k, ef=args.ef, m=args.m)
# print(f"Average recall: {recall}, avg calc: {avg_cal}")
# print("------")
# if __name__ == "__main__":
# main()
import matplotlib.pyplot as plt
import pandas as pd
# Plot the avg_calc_* columns of res_.csv against the ef column for the
# baseline and the modified implementation, and save the figure to
# avg_calc.png. The CSV must provide the columns "ef", "avg_calc_baseline"
# and "avg_calc_my".
df = pd.read_csv('res_.csv')
print(df.columns)  # quick sanity check of the columns that were loaded

# Label each curve where it is drawn so the legend cannot fall out of sync
# with the plotting order (previously both curves were labelled "line 1" and
# the legend silently overrode them positionally).
plt.plot(df["ef"], df["avg_calc_baseline"], label="baseline")
plt.plot(df["ef"], df["avg_calc_my"], label="mine")
plt.legend()
plt.xlabel("ef")
# The plotted values are the avg_calc_* columns, not recall, so the axis is
# named after what is actually drawn (and matches the output file name).
plt.ylabel("avg calc")
plt.savefig('avg_calc.png')