@@ -53,6 +53,10 @@ def _ensure_model(self):
5353 self ._model = _ms2ds_load_model (self .model_path )
5454 self ._model .eval ()
5555 return self ._model
56+
def _ensure_index(self):
    """Guard that an embedding index is attached to this instance.

    Raises
    ------
    RuntimeError
        If ``self.embedding_index`` is None.
    """
    if self.embedding_index is not None:
        return
    raise RuntimeError("EmbeddingIndex is not set. Build or load it before querying.")
5660
5761 # ----------------------------- core API -----------------------------
5862
@@ -111,12 +115,8 @@ def query_embedding_index(
111115 If True, returns a tidy DataFrame with columns:
112116 ['query_ix','rank','spec_id','score']
113117 """
114- if self .embedding_index is None :
115- raise RuntimeError ("EmbeddingIndex is not set. Build or load it before querying." )
116-
117- # Single → list
118- if isinstance (spectra , Spectrum ):
119- spectra = [spectra ]
118+ self ._ensure_index ()
119+ spectra = _ensure_spectra_list (spectra )
120120
121121 # Compute embeddings (L2-normalized)
122122 embeddings = self .compute_embeddings (spectra )
@@ -144,13 +144,57 @@ def query_embedding_index(
144144
def query_compounds_by_spectra(
    self,
    spectra: Union[Spectrum, Sequence[Spectrum]],
    *,
    k_spectra: int = 100,
    k_compounds: int = 10,
    ef: Optional[int] = None,
) -> pd.DataFrame:
    """
    Query the embedding index with spectra, return top-k_compounds per spectrum.

    The k_spectra nearest library spectra are retrieved for every query
    spectrum, joined with their metadata, collapsed to the best-scoring hit
    per (query, compound) pair and finally truncated to the k_compounds
    best compounds per query.

    Parameters
    ----------
    spectra : Spectrum or sequence of Spectrum
        Query spectra. A single Spectrum is wrapped into a list.
    k_spectra : int
        Number of top spectra to retrieve from the embedding index.
    k_compounds : int
        Number of top compounds to return per query spectrum.
    ef : Optional[int]
        nmslib ef parameter (higher = better recall / slower).

    Returns
    -------
    pandas.DataFrame
        One row per (query_ix, comp_id) pair, at most k_compounds rows per
        query_ix, ordered by query_ix and then by descending score. Holds
        the columns of the metadata table merged with the columns of the
        embedding query result.

    Raises
    ------
    RuntimeError
        If no embedding index is set.
    ValueError
        If `spectra` has an unsupported type, if k_compounds is smaller
        than 1, or if k_compounds is larger than k_spectra.
    """
    self._ensure_index()
    spectra = _ensure_spectra_list(spectra)

    # A negative value would make GroupBy.head() *exclude* rows from the end
    # of each group instead of limiting the result, so reject it up front.
    if k_compounds < 1:
        raise ValueError("k_compounds must be a positive integer")
    if k_compounds > k_spectra:
        raise ValueError("k_compounds cannot be larger than k_spectra")

    # Step 1: query spectral embeddings.
    results = self.query_embedding_index(spectra, k=k_spectra, ef=ef)

    # Step 2: fetch metadata once per distinct library spectrum. The same
    # spectrum can be hit by several queries; tolist() also hands plain
    # Python values (not numpy scalars) to the database layer.
    spec_ids = results["spec_id"].drop_duplicates().tolist()
    metadata = self.db.metadata_by_spec_ids(spec_ids)

    # Column-on-column merge ignores both indexes; the explicit reset
    # guarantees the unique positional index that idxmax()/.loc rely on.
    hits = (
        metadata
        .merge(results, on="spec_id")
        .sort_values(["query_ix", "rank"])
        .reset_index(drop=True)
    )

    # Step 3: keep only the best-scoring spectrum per (query, compound) pair.
    best_rows = hits.groupby(["query_ix", "comp_id"])["score"].idxmax()
    best_per_pair = hits.loc[best_rows]

    # Step 4: within each query_ix, keep the top-k_compounds by score.
    return (
        best_per_pair
        .sort_values(["query_ix", "score"], ascending=[True, False])
        .groupby("query_ix", group_keys=False)
        .head(k_compounds)
        .reset_index(drop=True)
    )
196+
197+
154198 # ----------------------------- helpers / optional glue -----------------------------
155199
156200 def set_embedding_index (self , index : EmbeddingIndex ) -> None :
@@ -185,7 +229,6 @@ def query_by_spec_ids(
185229 if not return_dataframe :
186230 return results_all
187231
188- import pandas as pd
189232 rows = []
190233 for qi , lst in enumerate (results_all ):
191234 for item in lst :
@@ -194,5 +237,14 @@ def query_by_spec_ids(
194237
@staticmethod
def _empty_result_df():
    """Return a zero-row DataFrame carrying the query-result column layout."""
    result_columns = ["query_ix", "rank", "spec_id", "score"]
    return pd.DataFrame(columns=result_columns)
241+
242+
243+ # ----------------- helper functions ---------------------
244+
def _ensure_spectra_list(spectra: Union[Spectrum, Sequence[Spectrum]]) -> List[Spectrum]:
    """Normalize a single spectrum or a sequence of spectra to a new list.

    Parameters
    ----------
    spectra : Spectrum or sequence of Spectrum
        A single Spectrum, or any Sequence (e.g. list or tuple) of them.

    Returns
    -------
    list
        ``[spectra]`` for a single Spectrum, otherwise a shallow copy of the
        sequence as a list.

    Raises
    ------
    ValueError
        If `spectra` is neither a Spectrum nor a sequence, or if it is a
        str / bytes / bytearray.
    """
    if isinstance(spectra, Spectrum):
        return [spectra]
    # Text and byte strings are Sequences too; without this exclusion they
    # would be split silently into a list of characters / ints.
    if isinstance(spectra, Sequence) and not isinstance(spectra, (str, bytes, bytearray)):
        return list(spectra)
    raise ValueError("spectra must be a Spectrum or a sequence of Spectrum objects.")
0 commit comments