|
4 | 4 | import numpy as np |
5 | 5 | from time import perf_counter |
6 | 6 | from threading import Lock, Condition |
7 | | -import hnswlib |
8 | 7 | import os |
9 | 8 | from usearch.index import Index, MetricKind |
# Worker-thread budget: leave one core free, never fewer than 1 and never more than 4.
# os.cpu_count() returns None when the core count cannot be determined, so fall back
# to 1 instead of raising TypeError at import time.
MAX_THREADS = min(4, max(1, (os.cpu_count() or 1) - 1))
11 | | -# TODO: Consider using the hnswlib library for approximate nearest neighbor for large datasets (n >= 1e5) |
12 | | -# TODO: Consider Usearch library for approximate nearest neighbor |
13 | 10 |
|
14 | 11 | class USearchTree: |
15 | 12 | def __init__(self, data, space='l2sq', ef_construction=200, M=100, ef=20): |
@@ -106,91 +103,6 @@ def replace(self, data, ids): |
106 | 103 | self.index.remove(ids) |
107 | 104 | self.index.add(ids, data) |
108 | 105 |
|
class HNSWTree:
    """Approximate nearest-neighbour index backed by ``hnswlib.Index``."""

    def __init__(self, data, space='l2', ef_construction=200, M=16, num_elements=int(4e6), ef=20):
        """
        Initialize HNSW tree with the given data.

        Parameters:
        - data: A 2D numpy array of shape (N, D), where N is the number of points, and D is the dimensionality.
        - space: Metric to use for distance calculation. Options: 'l2' (Euclidean), 'ip' (inner product).
        - ef_construction: The size of the dynamic candidate list during construction.
        - M: Controls the number of bi-directional links created for every new element during construction.
        - num_elements: Capacity passed to the index as ``max_elements``.
        - ef: Initial query-time ``ef`` value (see ``set_ef``).
        """
        self.num_elements = num_elements
        _, self.dim = data.shape
        self.index = hnswlib.Index(space=space, dim=self.dim)
        self.index.init_index(
            max_elements=self.num_elements,
            ef_construction=ef_construction,
            M=M,
            allow_replace_deleted=True,
        )

        # Use roughly half of the cores minus one, but always at least one thread.
        # os.cpu_count() may return None, and the raw expression is <= 0 on
        # machines with four or fewer cores.
        cpu_total = os.cpu_count() or 1
        self.index.set_num_threads(max(1, cpu_total // 2 - 1))

        # Initial points are labelled 0..N-1 in insertion order.
        labels = np.arange(data.shape[0])
        self.index.add_items(data, labels)

        # Control the trade-off between accuracy and speed at query time
        self.index.set_ef(ef)

    def query(self, points, k=1):
        """
        Query the nearest neighbors for the given points.

        Parameters:
        - points: A 2D numpy array of shape (N, D) where N is the number of query points.
        - k: The number of nearest neighbors to return.

        Returns:
        - distances: A 2D array of distances to the nearest neighbors.
        - indices: A 2D array of indices of the nearest neighbors.
        """
        # knn_query yields (labels, distances); this method returns them swapped.
        indices, distances = self.index.knn_query(points, k=k)
        return distances, indices

    def query_ball_point(self, points, r, **kwargs):
        """
        Query points within distance `r` from the given points.

        This is an approximation: a k-nearest-neighbour search is run and its
        results are filtered by `r`, so at most `k` neighbours are ever
        returned per query point.

        Parameters:
        - points: A 2D numpy array of shape (N, D) where N is the number of query points.
        - r: The radius to search within.
        - k (keyword, optional): Number of candidates to fetch before filtering.
          Defaults to 50; always capped at the number of elements in the index.

        Returns:
        - A list of lists, where each sublist contains the indices of points within distance `r` of the corresponding query point.
        """
        # Never request more neighbours than the index holds, including when the
        # caller supplies an explicit `k`.
        k = min(kwargs.get("k", 50), self.index.element_count)
        indices, distances = self.index.knn_query(points, k=k)

        # NOTE(review): hnswlib documents the 'l2' space as *squared* L2, so `r`
        # is compared against squared distances here -- confirm callers pass `r`
        # accordingly.
        return [
            idx_row[dist_row <= r].tolist()
            for idx_row, dist_row in zip(indices, distances)
        ]

    def set_ef(self, ef):
        """
        Set the `ef` parameter, which controls the search depth and the trade-off between speed and accuracy.

        Parameters:
        - ef: The size of the dynamic list for nearest neighbors search.
        """
        self.index.set_ef(ef)

    def add_items(self, data, ids):
        """
        Add more points to the HNSW index.

        Parameters:
        - data: A 2D numpy array of new points to add.
        - ids: Labels to store the new points under, one per row of `data`.
        """
        self.index.add_items(data, ids, num_threads=1)
        # NOTE(review): `num_elements` was passed to init_index as the capacity;
        # it is incremented here but the index itself is not resized -- confirm
        # what readers of this attribute expect.
        self.num_elements += data.shape[0]

    def replace(self, data, ids):
        """
        Replace the points stored under `ids` with the rows of `data`.

        Each id is first marked as deleted, then the new rows are inserted with
        ``replace_deleted=True``.

        Parameters:
        - data: A 2D numpy array of replacement points.
        - ids: Labels of the points to replace, one per row of `data`.
        """
        for label in ids:
            self.index.mark_deleted(label)
        self.index.add_items(data, ids, num_threads=1, replace_deleted=True)
193 | | - |
194 | 106 | class ReadWriteLock: |
195 | 107 | def __init__(self): |
196 | 108 | self._read_ready = Condition(Lock()) |
|
0 commit comments