Skip to content

Commit 4e208cf

Browse files
committed
add merge_fingerprint method
1 parent cf6a805 commit 4e208cf

File tree

3 files changed

+148
-6
lines changed

3 files changed

+148
-6
lines changed

ms2query/data_processing/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from .chemistry_utils import compute_morgan_fingerprints, inchikey14_from_full
2-
from .fingerprint_computation import compute_fingerprints_from_smiles
2+
from .fingerprint_computation import compute_fingerprints_from_smiles, merge_fingerprints
33
from .merging_utils import cluster_block, get_merged_spectra
44
from .spectra_processing import compute_spectra_embeddings, normalize_spectrum_sum
55

@@ -11,5 +11,6 @@
1111
"compute_spectra_embeddings",
1212
"get_merged_spectra",
1313
"inchikey14_from_full",
14+
"merge_fingerprints",
1415
"normalize_spectrum_sum",
1516
]

ms2query/data_processing/fingerprint_computation.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
from typing import Optional, Sequence, Tuple
12
import numba
23
import numpy as np
34
from numba import typed, types
5+
from numpy.typing import NDArray
46
from rdkit import Chem
57
from tqdm import tqdm
68

@@ -255,6 +257,105 @@ def count_fingerprint_keys(fingerprints):
255257
return unique_keys[order], count_arr[order], first_arr[order]
256258

257259

260+
def merge_fingerprints(
    fingerprints: Sequence[Tuple[NDArray[np.integer], NDArray[np.floating]]],
    weights: Optional[NDArray[np.floating]] = None,
) -> Tuple[NDArray[np.integer], NDArray[np.floating]]:
    """
    Merge multiple sparse Morgan (count/TF-IDF) fingerprints into a single
    weighted-average fingerprint.

    Parameters
    ----------
    fingerprints :
        Sequence of (bits, values) pairs.
        - bits: 1D integer array of bit indices (non-zero entries)
        - values: 1D float array of TF-IDF (or other) weights,
          same length as `bits`.
    weights :
        Optional array-like with one weight per fingerprint. It is
        flattened to 1D, so an (n, 1) column vector (e.g. the output of
        a pairwise similarity call) is accepted as well. Each
        fingerprint's values are scaled by its weight, then the merged
        fingerprint is normalized by the sum of all weights.

        - If None, all fingerprints are weighted equally (weight = 1.0).

    Returns
    -------
    merged_bits, merged_values :
        - merged_bits: 1D integer array of unique bit indices
        - merged_values: 1D float32 array of weighted-average values per
          bit (sum over all weighted fingerprints, divided by
          sum(weights)).

    Raises
    ------
    ValueError
        If `weights` has a length other than ``len(fingerprints)``, if
        the sum of weights is not positive, or if any fingerprint's bits
        and values arrays differ in length.
    """
    n_fps = len(fingerprints)
    if n_fps == 0:
        return _empty_sparse_fingerprint()

    if weights is not None:
        # Flatten so both (n,) and (n, 1) weight arrays are accepted.
        w = np.asarray(weights, dtype=np.float64).ravel()
        if w.shape[0] != n_fps:
            raise ValueError(
                f"weights must have length {n_fps}, got {w.shape[0]}"
            )
        total_weight = float(w.sum())
        if total_weight <= 0.0:
            raise ValueError("Sum of weights must be positive.")
    else:
        # Equal weighting; skip the per-fingerprint scaling below.
        w = None
        total_weight = float(n_fps)

    # Collect all bit indices and (weighted) values across fingerprints.
    bits_list = []
    vals_list = []

    for i, (bits, vals) in enumerate(fingerprints):
        bits = np.asarray(bits)
        vals = np.asarray(vals, dtype=np.float64)

        if bits.shape[0] != vals.shape[0]:
            raise ValueError(
                f"Fingerprint {i}: bits and values must have same length, "
                f"got {bits.shape[0]} and {vals.shape[0]}"
            )

        if w is not None:
            vals = vals * w[i]

        bits_list.append(bits)
        vals_list.append(vals)

    # NOTE: n_fps > 0 guarantees bits_list is non-empty here, so no
    # extra emptiness check is needed before concatenating.
    all_bits = np.concatenate(bits_list)
    all_vals = np.concatenate(vals_list)

    # Every fingerprint may still have zero non-zero bits.
    if all_bits.size == 0:
        return _empty_sparse_fingerprint()

    # Group by bit index and sum the weighted values per unique bit.
    unique_bits, inverse = np.unique(all_bits, return_inverse=True)
    summed_vals = np.bincount(inverse, weights=all_vals)

    # Weighted average: divide by the sum of weights.
    avg_vals = summed_vals / total_weight

    # Keep dtypes reasonably tight: bit indices keep their input dtype,
    # values are stored as float32.
    merged_bits = unique_bits.astype(all_bits.dtype, copy=False)
    merged_vals = avg_vals.astype(np.float32, copy=False)

    return merged_bits, merged_vals


def _empty_sparse_fingerprint() -> Tuple[NDArray[np.integer], NDArray[np.floating]]:
    """Return an empty sparse fingerprint with consistent dtypes.

    Centralizes the duplicated empty-return literals and fixes their
    dtype inconsistency: values are always float32, matching the
    non-empty return path of ``merge_fingerprints``.
    """
    return np.array([], dtype=np.int64), np.array([], dtype=np.float32)
358+
258359
### ------------------------
259360
### Bit Scaling and Weighting
260361
### ------------------------

ms2query/ms2query_library.py

Lines changed: 45 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44
import pandas as pd
55
from matchms import Spectrum
66
from ms2deepscore.models import load_model as _ms2ds_load_model
7+
from sklearn.metrics.pairwise import cosine_similarity
78
from ms2query import MS2QueryDatabase
8-
from ms2query.data_processing import compute_spectra_embeddings
9+
from ms2query.data_processing import compute_spectra_embeddings, merge_fingerprints
910
from ms2query.database import EmbeddingIndex, FingerprintSparseIndex
1011

1112

@@ -31,6 +32,7 @@ class MS2QueryLibrary:
3132
db: MS2QueryDatabase
3233
embedding_index: Optional[EmbeddingIndex] = None
3334
fingerprint_index: Optional[FingerprintSparseIndex] = None # for now: reference spectra only
35+
large_scale_fingerprint_index: Optional[FingerprintSparseIndex] = None # for large body of reference compounds
3436
model_path: Optional[str] = None
3537

3638
# internal: whether to apply spectrum normalization (sum=1) before embedding
@@ -319,13 +321,51 @@ def analogue_search(
319321
.set_index("spec_id")
320322
)
321323

322-
smiles = analogue_compounds["smiles"].tolist()
324+
analogue_smiles = analogue_compounds["smiles"].tolist()
323325

324326
# Step 3: fingerprint-based compound search
325327
top_compounds = self.query_compounds_by_compounds(
326-
smiles, k_compounds=k_compounds
327-
)
328-
return top_compounds
328+
smiles=analogue_smiles
329+
).set_index("query_ix")
330+
331+
# Step 4: for each query, pick the best matching spectrum among all spectra
332+
fingerprints_merged = []
333+
weighted_average_scores = []
334+
embeddings_queries = self.compute_embeddings(spectra) # TODO: this is now done twice! in step 1 and here
335+
for i in range(len(analogue_smiles)):
336+
comp_ids = top_compounds.loc[i].comp_id.to_list()
337+
338+
# Get chemically closest compounds
339+
spec_ids_all = []
340+
spec_ids_selected = []
341+
embeddings_selected = []
342+
343+
all_spec_ids = self.db.spec_ids_by_comp_ids(comp_ids).set_index("comp_id")
344+
for comp_id in comp_ids:
345+
new_spec_ids = all_spec_ids.loc[comp_id].spec_id.to_list()
346+
347+
# Get most similar embedding from one of the top-10 compounds
348+
embs = self.db.ref_sdb.get_embeddings(new_spec_ids)
349+
similarities = cosine_similarity(embs[1], embeddings_queries[i].reshape(1, -1))
350+
max_id = np.argmax(similarities)
351+
spec_ids_selected.append(embs[0][max_id])
352+
embeddings_selected.append(embs[1][max_id])
353+
spec_ids_all.extend(new_spec_ids)
354+
355+
top1_top10_similarities = cosine_similarity(embeddings_selected, embeddings_queries[i].reshape(1, -1))
356+
fingerprints = self.db.ref_cdb.get_fingerprints(comp_ids)
357+
fingerprints_merged.append(merge_fingerprints(fingerprints, weights=top1_top10_similarities))
358+
weighted_average_scores.append(np.sum(top1_top10_similarities ** 2) / np.sum(top1_top10_similarities))
359+
if self.large_scale_fingerprint_index:
360+
analogue_predictions = self.large_scale_fingerprint_index.query(fingerprints_merged, k=k_compounds)
361+
elif self.fingerprint_index:
362+
analogue_predictions = self.fingerprint_index.query(fingerprints_merged, k=k_compounds)
363+
else:
364+
raise RuntimeError("No fingerprint index is set. Build or load it before querying.")
365+
return pd.DataFrame({
366+
"analogue_predictions": analogue_predictions,
367+
"weighted_average_scores": weighted_average_scores
368+
})
329369

330370
# ------------------------------------------------------------------
331371
# Helpers / glue

0 commit comments

Comments
 (0)