@@ -485,51 +485,141 @@ def build_index(
485485
486486 def query (
487487 self ,
488- query_fp : Tuple [np .ndarray , np .ndarray ] | sp .csr_matrix ,
488+ query_fp : (
489+ Tuple [np .ndarray , np .ndarray ]
490+ | sp .csr_matrix
491+ | Sequence [Tuple [np .ndarray , np .ndarray ]]
492+ ),
489493 k : int = 10 ,
490494 * ,
491495 ef : Optional [int ] = None ,
492496 re_rank : bool = True ,
493497 candidate_multiplier : int = 5 ,
494- ) -> List [Tuple [int , float ]]:
498+ num_threads : int = 0 ,
499+ ) -> List [Tuple [int , float ]] | List [List [Tuple [int , float ]]]:
495500 """
496501 Query for k nearest neighbors.
497502
498503 Parameters
499504 ----------
500- query_fp : (indices, values) tuple or single-row CSR
501- k : Number of results
502- re_rank : Use exact Tanimoto re-ranking
503- candidate_multiplier : Fetch k * multiplier candidates for re-ranking
505+ query_fp :
506+ - Single query:
507+ * (indices, values) tuple
508+ * single-row CSR of shape (1, dim)
509+ - Batched queries:
510+ * CSR of shape (N, dim)
511+ * Sequence of (indices, values) tuples
512+ k : int
513+ Number of results per query.
514+ re_rank : bool
515+ Use exact Tanimoto re-ranking.
516+ candidate_multiplier : int
517+ Fetch k * multiplier candidates for re-ranking.
518+ num_threads : int
519+ Number of threads to use inside nmslib (0 = library default).
504520
505- Returns list of (comp_id, similarity) tuples.
521+ Returns
522+ -------
523+ Union[List[Tuple[int, float]], List[List[Tuple[int, float]]]]
524+ - For a single query, returns a list of (comp_id, similarity).
525+ - For multiple queries, returns a list (per query) of such lists.
506526 """
507527 if self ._index is None :
508528 raise RuntimeError ("Index not built or loaded." )
509529
510- q = self ._normalize_query (query_fp )
511- if q .nnz == 0 :
512- return []
530+ # -------------------------
531+ # Normalize input to CSR
532+ # -------------------------
533+ single = False
534+
535+ if isinstance (query_fp , sp .csr_matrix ):
536+ Q = query_fp .astype (np .float32 , copy = False )
537+ if Q .shape [1 ] != self .dim :
538+ raise ValueError (f"CSR query must have shape (N, { self .dim } )" )
539+ single = Q .shape [0 ] == 1
540+
541+ elif isinstance (query_fp , tuple ):
542+ # Single (indices, values)
543+ Q = csr_row_from_tuple (query_fp , dim = self .dim )
544+ single = True
545+
546+ else :
547+ # Assume sequence of (indices, values) tuples -> batched queries
548+ Q = tuples_to_csr (query_fp , dim = self .dim )
549+ single = Q .shape [0 ] == 1
550+
551+ if (Q .data < 0 ).any ():
552+ raise ValueError ("Query must be non-negative for Tanimoto." )
553+
554+ # Handle completely empty queries quickly
555+ row_nnz = Q .indptr [1 :] - Q .indptr [:- 1 ]
556+ if row_nnz .sum () == 0 :
557+ if single :
558+ return []
559+ return [[] for _ in range (Q .shape [0 ])]
513560
514561 if ef is not None :
515562 self ._index .setQueryTimeParams ({"ef" : ef })
516563
517564 fetch = max (k , k * candidate_multiplier )
518- idxs , dists = self ._index .knnQueryBatch (q , k = fetch )[0 ]
519- idxs = np .asarray (idxs , dtype = np .int64 )
520- dists = np .asarray (dists , dtype = np .float32 )
521565
522- # Without re-ranking, return cosine similarities
566+ # -------------------------
567+ # ANN search for all queries
568+ # -------------------------
569+ batch_results = self ._index .knnQueryBatch (Q , k = fetch , num_threads = num_threads )
570+
571+ # -------------------------
572+ # No re-ranking: cosine sims only
573+ # -------------------------
523574 if not re_rank or self ._csr is None or self ._l1 is None :
524- sims = 1.0 - dists
525- return [(self ._comp_ids [i ], float (s )) for i , s in zip (idxs [:k ], sims [:k ])]
575+ all_out : List [List [Tuple [int , float ]]] = []
576+
577+ for qi , (idxs , dists ) in enumerate (batch_results ):
578+ if row_nnz [qi ] == 0 :
579+ all_out .append ([])
580+ continue
581+
582+ idxs = np .asarray (idxs , dtype = np .int64 )
583+ dists = np .asarray (dists , dtype = np .float32 )
584+
585+ sims = 1.0 - dists
586+ out = [
587+ (self ._comp_ids [i ], float (s ))
588+ for i , s in zip (idxs [:k ], sims [:k ])
589+ ]
590+ all_out .append (out )
591+
592+ return all_out [0 ] if single else all_out
593+
594+ # -------------------------
595+ # Exact Tanimoto re-ranking
596+ # -------------------------
597+ all_out : List [List [Tuple [int , float ]]] = []
526598
527- # Re-rank with exact Tanimoto
528- Y = self . _csr [ idxs ]
529- tan = tanimoto_l1_query_vs_block ( q , Y , sum1 = float ( q . sum ()), sumsY = self . _l1 [ idxs ])
530- order = np . argsort ( - tan )[: k ]
599+ for qi , ( idxs , dists ) in enumerate ( batch_results ):
600+ if row_nnz [ qi ] == 0 :
601+ all_out . append ([ ])
602+ continue
531603
532- return [(self ._comp_ids [idxs [i ]], float (tan [i ])) for i in order ]
604+ idxs = np .asarray (idxs , dtype = np .int64 )
605+
606+ q_row = Q [qi ]
607+ Y = self ._csr [idxs ]
608+ tan = tanimoto_l1_query_vs_block (
609+ q_row ,
610+ Y ,
611+ sum1 = float (q_row .sum ()),
612+ sumsY = self ._l1 [idxs ],
613+ )
614+
615+ order = np .argsort (- tan )[:k ]
616+ out = [
617+ (self ._comp_ids [idxs [i ]], float (tan [i ]))
618+ for i in order
619+ ]
620+ all_out .append (out )
621+
622+ return all_out [0 ] if single else all_out
533623
534624 def _normalize_query (self , query_fp ) -> sp .csr_matrix :
535625 """Convert query to single-row CSR and validate."""
0 commit comments