Skip to content

Commit 90945a8

Browse files
committed
Used mask
1 parent 0d034de commit 90945a8

File tree

5 files changed

+12
-6
lines changed

5 files changed

+12
-6
lines changed

vicinity/backends/annoy.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,5 +130,6 @@ def threshold(self, vectors: npt.NDArray, threshold: float, max_k: int) -> Query
130130
"""Threshold the backend."""
131131
out: QueryResult = []
132132
for x, y in self.query(vectors, max_k):
133-
out.append((x[y < threshold], y[y < threshold]))
133+
mask = y < threshold
134+
out.append((x[mask], y[mask]))
134135
return out

vicinity/backends/faiss.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -179,13 +179,15 @@ def threshold(self, vectors: npt.NDArray, threshold: float, max_k: int) -> Query
179179
dist = D[start:end]
180180
if self.arguments.metric == "cosine":
181181
dist = 1 - dist
182-
out.append((idx[dist < threshold], dist[dist < threshold]))
182+
mask = dist < threshold
183+
out.append((idx[mask], dist[mask]))
183184
else:
184185
distances, indices = self.index.search(vectors, max_k)
185186
for dist, idx in zip(distances, indices):
186187
if self.arguments.metric == "cosine":
187188
dist = 1 - dist
188-
out.append((idx[dist < threshold], dist[dist < threshold]))
189+
mask = dist < threshold
190+
out.append((idx[mask], dist[mask]))
189191

190192
return out
191193

vicinity/backends/hnsw.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ def threshold(self, vectors: npt.NDArray, threshold: float, max_k: int) -> Query
108108
"""Threshold the backend."""
109109
out: QueryResult = []
110110
for x, y in self.query(vectors, max_k):
111-
out.append((x[y < threshold], y[y < threshold]))
111+
mask = y < threshold
112+
out.append((x[mask], y[mask]))
112113

113114
return out

vicinity/backends/pynndescent.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,8 @@ def threshold(self, vectors: npt.NDArray, threshold: float, max_k: int) -> Query
8686
indices, distances = self.index.query(normalized_vectors, k=max_k)
8787
out: QueryResult = []
8888
for idx, dist in zip(indices, distances):
89-
out.append((idx[dist < threshold], dist[dist < threshold]))
89+
mask = dist < threshold
90+
out.append((idx[mask], dist[mask]))
9091
return out
9192

9293
def save(self, base_path: Path) -> None:

vicinity/backends/voyager.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,8 @@ def threshold(self, vectors: npt.NDArray, threshold: float, max_k: int) -> Query
9898
"""Threshold the backend."""
9999
out: list[tuple[npt.NDArray, npt.NDArray]] = []
100100
for x, y in self.query(vectors, max_k):
101-
out.append((x[y < threshold], y[y < threshold]))
101+
mask = y < threshold
102+
out.append((x[mask], y[mask]))
102103

103104
return out
104105

0 commit comments

Comments
 (0)