@@ -323,28 +323,63 @@ def _create_hnsw_index(
323323
324324 def query (
325325 self ,
326- vector : np .ndarray ,
326+ vectors : np .ndarray ,
327327 k : int = 10 ,
328328 ef : Optional [int ] = None ,
329- ) -> List [Tuple [str , float ]]:
329+ num_threads : int = 0 ,
330+ ) -> List [Tuple [str , float ]] | List [List [Tuple [str , float ]]]:
330331 """
331332 Query for k nearest neighbors.
332333
333- Returns list of (spec_id, similarity) tuples.
334+ Parameters
335+ ----------
336+ vectors : np.ndarray
337+ Either a single vector of shape (dim,) or a batch of shape (N, dim).
338+ k : int
339+ Number of neighbors.
340+ ef : Optional[int]
341+ Optional per-query ef parameter for HNSW.
342+ num_threads : int
343+ Number of threads to use inside nmslib (0 = library default).
344+
345+ Returns
346+ -------
347+ Union[List[Tuple[str, float]], List[List[Tuple[str, float]]]]
348+ - If a single vector is given, returns a list of (spec_id, similarity).
349+ - If a batch is given, returns a list (per query) of such lists.
334350 """
335351 if self ._index is None :
336352 raise RuntimeError ("Index not built or loaded." )
337353
338- v = np .asarray (vector , dtype = np .float32 ).reshape (1 , - 1 )
339- if v .shape [1 ] != self .dim :
340- raise ValueError (f"Query must have dim={ self .dim } " )
354+ X = np .asarray (vectors , dtype = np .float32 )
355+
356+ single = False
357+ if X .ndim == 1 :
358+ # Single query vector: (dim,) -> (1, dim)
359+ if X .size != self .dim :
360+ raise ValueError (f"Query must have dim={ self .dim } " )
361+ X = X .reshape (1 , - 1 )
362+ single = True
363+ elif X .ndim == 2 :
364+ if X .shape [1 ] != self .dim :
365+ raise ValueError (f"Expected shape (N, { self .dim } ), got { X .shape } " )
366+ else :
367+ raise ValueError ("vectors must be 1D or 2D array." )
341368
342369 if ef is not None :
343370 self ._index .setQueryTimeParams ({"ef" : ef })
344371
345- idxs , dists = self ._index .knnQueryBatch (v , k = k )[0 ]
346- sims = 1.0 - np .asarray (dists , dtype = np .float32 ) # cosine distance -> similarity
347- return [(str (self ._ids [i ]), float (sims [j ])) for j , i in enumerate (idxs )]
372+ batch_results = self ._index .knnQueryBatch (X , k = k , num_threads = num_threads )
373+
374+ all_out : List [List [Tuple [str , float ]]] = []
375+ for idxs , dists in batch_results :
376+ idxs = np .asarray (idxs , dtype = np .int64 )
377+ dists = np .asarray (dists , dtype = np .float32 )
378+ sims = 1.0 - dists # cosine distance -> similarity
379+ out = [(str (self ._ids [i ]), float (s )) for i , s in zip (idxs , sims )]
380+ all_out .append (out )
381+
382+ return all_out [0 ] if single else all_out
348383
349384 def save_index (self , path_prefix : str ) -> None :
350385 if self ._index is None :
@@ -450,51 +485,141 @@ def build_index(
450485
451486 def query (
452487 self ,
453- query_fp : Tuple [np .ndarray , np .ndarray ] | sp .csr_matrix ,
488+ query_fp : (
489+ Tuple [np .ndarray , np .ndarray ]
490+ | sp .csr_matrix
491+ | Sequence [Tuple [np .ndarray , np .ndarray ]]
492+ ),
454493 k : int = 10 ,
455494 * ,
456495 ef : Optional [int ] = None ,
457496 re_rank : bool = True ,
458497 candidate_multiplier : int = 5 ,
459- ) -> List [Tuple [int , float ]]:
498+ num_threads : int = 0 ,
499+ ) -> List [Tuple [int , float ]] | List [List [Tuple [int , float ]]]:
460500 """
461501 Query for k nearest neighbors.
462502
463503 Parameters
464504 ----------
465- query_fp : (indices, values) tuple or single-row CSR
466- k : Number of results
467- re_rank : Use exact Tanimoto re-ranking
468- candidate_multiplier : Fetch k * multiplier candidates for re-ranking
469-
470- Returns list of (comp_id, similarity) tuples.
505+ query_fp :
506+ - Single query:
507+ * (indices, values) tuple
508+ * single-row CSR of shape (1, dim)
509+ - Batched queries:
510+ * CSR of shape (N, dim)
511+ * Sequence of (indices, values) tuples
512+ k : int
513+ Number of results per query.
514+ re_rank : bool
515+ Use exact Tanimoto re-ranking.
516+ candidate_multiplier : int
517+ Fetch k * multiplier candidates for re-ranking.
518+ num_threads : int
519+ Number of threads to use inside nmslib (0 = library default).
520+
521+ Returns
522+ -------
523+ Union[List[Tuple[int, float]], List[List[Tuple[int, float]]]]
524+ - For a single query, returns a list of (comp_id, similarity).
525+ - For multiple queries, returns a list (per query) of such lists.
471526 """
472527 if self ._index is None :
473528 raise RuntimeError ("Index not built or loaded." )
474529
475- q = self ._normalize_query (query_fp )
476- if q .nnz == 0 :
477- return []
530+ # -------------------------
531+ # Normalize input to CSR
532+ # -------------------------
533+ single = False
534+
535+ if isinstance (query_fp , sp .csr_matrix ):
536+ Q = query_fp .astype (np .float32 , copy = False )
537+ if Q .shape [1 ] != self .dim :
538+ raise ValueError (f"CSR query must have shape (N, { self .dim } )" )
539+ single = Q .shape [0 ] == 1
540+
541+ elif isinstance (query_fp , tuple ):
542+ # Single (indices, values)
543+ Q = csr_row_from_tuple (query_fp , dim = self .dim )
544+ single = True
545+
546+ else :
547+ # Assume sequence of (indices, values) tuples -> batched queries
548+ Q = tuples_to_csr (query_fp , dim = self .dim )
549+ single = Q .shape [0 ] == 1
550+
551+ if (Q .data < 0 ).any ():
552+ raise ValueError ("Query must be non-negative for Tanimoto." )
553+
554+ # Handle completely empty queries quickly
555+ row_nnz = Q .indptr [1 :] - Q .indptr [:- 1 ]
556+ if row_nnz .sum () == 0 :
557+ if single :
558+ return []
559+ return [[] for _ in range (Q .shape [0 ])]
478560
479561 if ef is not None :
480562 self ._index .setQueryTimeParams ({"ef" : ef })
481563
482564 fetch = max (k , k * candidate_multiplier )
483- idxs , dists = self ._index .knnQueryBatch (q , k = fetch )[0 ]
484- idxs = np .asarray (idxs , dtype = np .int64 )
485- dists = np .asarray (dists , dtype = np .float32 )
486565
487- # Without re-ranking, return cosine similarities
566+ # -------------------------
567+ # ANN search for all queries
568+ # -------------------------
569+ batch_results = self ._index .knnQueryBatch (Q , k = fetch , num_threads = num_threads )
570+
571+ # -------------------------
572+ # No re-ranking: cosine sims only
573+ # -------------------------
488574 if not re_rank or self ._csr is None or self ._l1 is None :
489- sims = 1.0 - dists
490- return [(int (self ._comp_ids [i ]), float (s )) for i , s in zip (idxs [:k ], sims [:k ])]
575+ all_out : List [List [Tuple [int , float ]]] = []
491576
492- # Re-rank with exact Tanimoto
493- Y = self . _csr [ idxs ]
494- tan = tanimoto_l1_query_vs_block ( q , Y , sum1 = float ( q . sum ()), sumsY = self . _l1 [ idxs ])
495- order = np . argsort ( - tan )[: k ]
577+ for qi , ( idxs , dists ) in enumerate ( batch_results ):
578+ if row_nnz [ qi ] == 0 :
579+ all_out . append ([ ])
580+ continue
496581
497- return [(int (self ._comp_ids [idxs [i ]]), float (tan [i ])) for i in order ]
582+ idxs = np .asarray (idxs , dtype = np .int64 )
583+ dists = np .asarray (dists , dtype = np .float32 )
584+
585+ sims = 1.0 - dists
586+ out = [
587+ (self ._comp_ids [i ], float (s ))
588+ for i , s in zip (idxs [:k ], sims [:k ])
589+ ]
590+ all_out .append (out )
591+
592+ return all_out [0 ] if single else all_out
593+
594+ # -------------------------
595+ # Exact Tanimoto re-ranking
596+ # -------------------------
597+ all_out : List [List [Tuple [int , float ]]] = []
598+
599+ for qi , (idxs , dists ) in enumerate (batch_results ):
600+ if row_nnz [qi ] == 0 :
601+ all_out .append ([])
602+ continue
603+
604+ idxs = np .asarray (idxs , dtype = np .int64 )
605+
606+ q_row = Q [qi ]
607+ Y = self ._csr [idxs ]
608+ tan = tanimoto_l1_query_vs_block (
609+ q_row ,
610+ Y ,
611+ sum1 = float (q_row .sum ()),
612+ sumsY = self ._l1 [idxs ],
613+ )
614+
615+ order = np .argsort (- tan )[:k ]
616+ out = [
617+ (self ._comp_ids [idxs [i ]], float (tan [i ]))
618+ for i in order
619+ ]
620+ all_out .append (out )
621+
622+ return all_out [0 ] if single else all_out
498623
499624 def _normalize_query (self , query_fp ) -> sp .csr_matrix :
500625 """Convert query to single-row CSR and validate."""
@@ -602,10 +727,11 @@ def save_index(self, path_prefix: str) -> None:
602727 if self ._index is None :
603728 raise RuntimeError ("Index not built." )
604729
605- self ._index .saveIndex (f"{ path_prefix } .nmslib" )
730+ # Also save data so that load_index(..., load_data=True) works
731+ self ._index .saveIndex (f"{ path_prefix } .nmslib" , save_data = True )
606732 np .save (f"{ path_prefix } .ids.npy" , self ._comp_ids )
607733
608- meta = {** self ._meta , "dim" : self .dim , "space" : self .space }
734+ meta = {** self ._meta , "dim" : int ( self .dim ) , "space" : str ( self .space ) }
609735 with open (f"{ path_prefix } .meta.json" , "w" ) as f :
610736 json .dump (meta , f )
611737
0 commit comments