Skip to content

Commit b0b7748

Browse files
committed
further batch querying method update
1 parent c827742 commit b0b7748

File tree

1 file changed

+111
-21
lines changed

1 file changed

+111
-21
lines changed

ms2query/database/ann_vector_index.py

Lines changed: 111 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -485,51 +485,141 @@ def build_index(
485485

486486
def query(
487487
self,
488-
query_fp: Tuple[np.ndarray, np.ndarray] | sp.csr_matrix,
488+
query_fp: (
489+
Tuple[np.ndarray, np.ndarray]
490+
| sp.csr_matrix
491+
| Sequence[Tuple[np.ndarray, np.ndarray]]
492+
),
489493
k: int = 10,
490494
*,
491495
ef: Optional[int] = None,
492496
re_rank: bool = True,
493497
candidate_multiplier: int = 5,
494-
) -> List[Tuple[int, float]]:
498+
num_threads: int = 0,
499+
) -> List[Tuple[int, float]] | List[List[Tuple[int, float]]]:
495500
"""
496501
Query for k nearest neighbors.
497502
498503
Parameters
499504
----------
500-
query_fp : (indices, values) tuple or single-row CSR
501-
k : Number of results
502-
re_rank : Use exact Tanimoto re-ranking
503-
candidate_multiplier : Fetch k * multiplier candidates for re-ranking
505+
query_fp :
506+
- Single query:
507+
* (indices, values) tuple
508+
* single-row CSR of shape (1, dim)
509+
- Batched queries:
510+
* CSR of shape (N, dim)
511+
* Sequence of (indices, values) tuples
512+
k : int
513+
Number of results per query.
514+
re_rank : bool
515+
Use exact Tanimoto re-ranking.
516+
candidate_multiplier : int
517+
Fetch k * multiplier candidates for re-ranking.
518+
num_threads : int
519+
Number of threads to use inside nmslib (0 = library default).
504520
505-
Returns list of (comp_id, similarity) tuples.
521+
Returns
522+
-------
523+
Union[List[Tuple[int, float]], List[List[Tuple[int, float]]]]
524+
- For a single query, returns a list of (comp_id, similarity).
525+
- For multiple queries, returns a list (per query) of such lists.
506526
"""
507527
if self._index is None:
508528
raise RuntimeError("Index not built or loaded.")
509529

510-
q = self._normalize_query(query_fp)
511-
if q.nnz == 0:
512-
return []
530+
# -------------------------
531+
# Normalize input to CSR
532+
# -------------------------
533+
single = False
534+
535+
if isinstance(query_fp, sp.csr_matrix):
536+
Q = query_fp.astype(np.float32, copy=False)
537+
if Q.shape[1] != self.dim:
538+
raise ValueError(f"CSR query must have shape (N, {self.dim})")
539+
single = Q.shape[0] == 1
540+
541+
elif isinstance(query_fp, tuple):
542+
# Single (indices, values)
543+
Q = csr_row_from_tuple(query_fp, dim=self.dim)
544+
single = True
545+
546+
else:
547+
# Assume sequence of (indices, values) tuples -> batched queries
548+
Q = tuples_to_csr(query_fp, dim=self.dim)
549+
single = Q.shape[0] == 1
550+
551+
if (Q.data < 0).any():
552+
raise ValueError("Query must be non-negative for Tanimoto.")
553+
554+
# Handle completely empty queries quickly
555+
row_nnz = Q.indptr[1:] - Q.indptr[:-1]
556+
if row_nnz.sum() == 0:
557+
if single:
558+
return []
559+
return [[] for _ in range(Q.shape[0])]
513560

514561
if ef is not None:
515562
self._index.setQueryTimeParams({"ef": ef})
516563

517564
fetch = max(k, k * candidate_multiplier)
518-
idxs, dists = self._index.knnQueryBatch(q, k=fetch)[0]
519-
idxs = np.asarray(idxs, dtype=np.int64)
520-
dists = np.asarray(dists, dtype=np.float32)
521565

522-
# Without re-ranking, return cosine similarities
566+
# -------------------------
567+
# ANN search for all queries
568+
# -------------------------
569+
batch_results = self._index.knnQueryBatch(Q, k=fetch, num_threads=num_threads)
570+
571+
# -------------------------
572+
# No re-ranking: cosine sims only
573+
# -------------------------
523574
if not re_rank or self._csr is None or self._l1 is None:
524-
sims = 1.0 - dists
525-
return [(self._comp_ids[i], float(s)) for i, s in zip(idxs[:k], sims[:k])]
575+
all_out: List[List[Tuple[int, float]]] = []
576+
577+
for qi, (idxs, dists) in enumerate(batch_results):
578+
if row_nnz[qi] == 0:
579+
all_out.append([])
580+
continue
581+
582+
idxs = np.asarray(idxs, dtype=np.int64)
583+
dists = np.asarray(dists, dtype=np.float32)
584+
585+
sims = 1.0 - dists
586+
out = [
587+
(self._comp_ids[i], float(s))
588+
for i, s in zip(idxs[:k], sims[:k])
589+
]
590+
all_out.append(out)
591+
592+
return all_out[0] if single else all_out
593+
594+
# -------------------------
595+
# Exact Tanimoto re-ranking
596+
# -------------------------
597+
all_out: List[List[Tuple[int, float]]] = []
526598

527-
# Re-rank with exact Tanimoto
528-
Y = self._csr[idxs]
529-
tan = tanimoto_l1_query_vs_block(q, Y, sum1=float(q.sum()), sumsY=self._l1[idxs])
530-
order = np.argsort(-tan)[:k]
599+
for qi, (idxs, dists) in enumerate(batch_results):
600+
if row_nnz[qi] == 0:
601+
all_out.append([])
602+
continue
531603

532-
return [(self._comp_ids[idxs[i]], float(tan[i])) for i in order]
604+
idxs = np.asarray(idxs, dtype=np.int64)
605+
606+
q_row = Q[qi]
607+
Y = self._csr[idxs]
608+
tan = tanimoto_l1_query_vs_block(
609+
q_row,
610+
Y,
611+
sum1=float(q_row.sum()),
612+
sumsY=self._l1[idxs],
613+
)
614+
615+
order = np.argsort(-tan)[:k]
616+
out = [
617+
(self._comp_ids[idxs[i]], float(tan[i]))
618+
for i in order
619+
]
620+
all_out.append(out)
621+
622+
return all_out[0] if single else all_out
533623

534624
def _normalize_query(self, query_fp) -> sp.csr_matrix:
535625
"""Convert query to single-row CSR and validate."""

0 commit comments

Comments
 (0)