Skip to content

Commit c64e352

Browse files
committed
refactoring and new query method
1 parent dec423d commit c64e352

File tree

1 file changed

+64
-12
lines changed

1 file changed

+64
-12
lines changed

ms2query/ms2query_library.py

Lines changed: 64 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,10 @@ def _ensure_model(self):
5353
self._model = _ms2ds_load_model(self.model_path)
5454
self._model.eval()
5555
return self._model
56+
57+
def _ensure_index(self):
58+
if self.embedding_index is None:
59+
raise RuntimeError("EmbeddingIndex is not set. Build or load it before querying.")
5660

5761
# ----------------------------- core API -----------------------------
5862

@@ -111,12 +115,8 @@ def query_embedding_index(
111115
If True, returns a tidy DataFrame with columns:
112116
['query_ix','rank','spec_id','score']
113117
"""
114-
if self.embedding_index is None:
115-
raise RuntimeError("EmbeddingIndex is not set. Build or load it before querying.")
116-
117-
# Single → list
118-
if isinstance(spectra, Spectrum):
119-
spectra = [spectra]
118+
self._ensure_index()
119+
spectra = _ensure_spectra_list(spectra)
120120

121121
# Compute embeddings (L2-normalized)
122122
embeddings = self.compute_embeddings(spectra)
@@ -144,13 +144,57 @@ def query_embedding_index(
144144

145145
def query_compounds_by_spectra(
    self,
    spectra: Union[Spectrum, Sequence[Spectrum]],
    *,
    k_spectra: int = 100,
    k_compounds: int = 10,
    ef: Optional[int] = None,
):
    """
    Query the embedding index with spectra, return top-k_compounds per spectrum.

    The ``k_spectra`` nearest library spectra are retrieved for every query,
    joined with their metadata, collapsed to the best-scoring hit per
    compound, and finally truncated to ``k_compounds`` compounds per query.

    Parameters
    ----------
    spectra : Spectrum or sequence of Spectrum
        Query spectra. A single spectrum is wrapped into a one-element list.
    k_spectra : int
        Number of top spectra to retrieve from the embedding index.
    k_compounds : int
        Number of top compounds to return per query spectrum.
    ef : Optional[int]
        nmslib ef parameter (higher = better recall / slower).

    Returns
    -------
    pandas.DataFrame
        Metadata columns merged with the hit columns of
        ``query_embedding_index``; one row per (query_ix, comp_id) pair,
        ordered by ``query_ix`` and descending ``score``, with at most
        ``k_compounds`` rows per ``query_ix``. Fewer rows are returned when
        the ``k_spectra`` hits cover fewer distinct compounds.

    Raises
    ------
    RuntimeError
        If no embedding index is set.
    ValueError
        If ``k_compounds`` is larger than ``k_spectra``.
    """
    self._ensure_index()
    spectra = _ensure_spectra_list(spectra)

    if k_compounds > k_spectra:
        raise ValueError("k_compounds cannot be larger than k_spectra")

    # Step 1: retrieve the k_spectra nearest library spectra per query.
    results = self.query_embedding_index(spectra, k=k_spectra, ef=ef)

    # Step 2: attach metadata to every hit. The same library spectrum can be
    # hit by several queries, so each spec_id is looked up only once; the
    # merge below fans the metadata back out to all matching hits.
    # NOTE(review): assumes the metadata frame carries a 'comp_id' column.
    unique_spec_ids = list(pd.unique(results["spec_id"]))
    compounds = self.db.metadata_by_spec_ids(unique_spec_ids).set_index("spec_id")
    compounds = compounds.merge(results, on="spec_id").sort_values(["query_ix", "rank"])

    # Step 3: several hits may belong to the same compound; keep only the
    # best-scoring hit for each (query_ix, comp_id) pair.
    idx = compounds.groupby(["query_ix", "comp_id"])["score"].idxmax()
    best_per_pair = compounds.loc[idx]

    # Step 4: within each query_ix, keep no more than k_compounds compounds,
    # highest score first.
    df_selected = (
        best_per_pair
        .sort_values(["query_ix", "score"], ascending=[True, False])
        .groupby("query_ix", group_keys=False)
        .head(k_compounds)
        .reset_index(drop=True)
    )

    return df_selected
196+
197+
154198
# ----------------------------- helpers / optional glue -----------------------------
155199

156200
def set_embedding_index(self, index: EmbeddingIndex) -> None:
@@ -185,7 +229,6 @@ def query_by_spec_ids(
185229
if not return_dataframe:
186230
return results_all
187231

188-
import pandas as pd
189232
rows = []
190233
for qi, lst in enumerate(results_all):
191234
for item in lst:
@@ -194,5 +237,14 @@ def query_by_spec_ids(
194237

195238
@staticmethod
196239
def _empty_result_df():
197-
import pandas as pd
198240
return pd.DataFrame(columns=["query_ix", "rank", "spec_id", "score"])
241+
242+
243+
# ----------------- helper functions ---------------------
244+
245+
def _ensure_spectra_list(spectra: Union[Spectrum, Sequence[Spectrum]]) -> List[Spectrum]:
    """Normalize the ``spectra`` argument to a new list.

    Parameters
    ----------
    spectra
        Either a single ``Spectrum`` or a sequence of them. The elements of
        a sequence are copied into the list as-is; they are not type-checked.

    Returns
    -------
    list
        A one-element list for a single ``Spectrum``, otherwise a shallow
        copy of the given sequence.

    Raises
    ------
    ValueError
        If ``spectra`` is neither a ``Spectrum`` nor a ``Sequence``.
    """
    if isinstance(spectra, Spectrum):
        normalized = [spectra]
    elif isinstance(spectra, Sequence):
        normalized = list(spectra)
    else:
        raise ValueError("spectra must be a Spectrum or a sequence of Spectrum objects.")
    return normalized

0 commit comments

Comments
 (0)