Skip to content

Commit 48e9dea

Browse files
committed
Merge branch 'main' into add_benchmarking_method
2 parents fd688c4 + 0ca72ef commit 48e9dea

13 files changed

Lines changed: 1086 additions & 288 deletions

ms2query/data_processing/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from .chemistry_utils import compute_morgan_fingerprints, inchikey14_from_full
2-
from .fingerprint_computation import compute_fingerprints_from_smiles
2+
from .fingerprint_computation import compute_fingerprints_from_smiles, merge_fingerprints
33
from .merging_utils import cluster_block, get_merged_spectra
44
from .spectra_processing import compute_spectra_embeddings, normalize_spectrum_sum
55

@@ -11,5 +11,6 @@
1111
"compute_spectra_embeddings",
1212
"get_merged_spectra",
1313
"inchikey14_from_full",
14+
"merge_fingerprints",
1415
"normalize_spectrum_sum",
1516
]

ms2query/data_processing/fingerprint_computation.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
from typing import Optional, Sequence, Tuple
12
import numba
23
import numpy as np
34
from numba import typed, types
5+
from numpy.typing import NDArray
46
from rdkit import Chem
57
from tqdm import tqdm
68

@@ -255,6 +257,105 @@ def count_fingerprint_keys(fingerprints):
255257
return unique_keys[order], count_arr[order], first_arr[order]
256258

257259

260+
def merge_fingerprints(
    fingerprints: Sequence[Tuple[NDArray[np.integer], NDArray[np.floating]]],
    weights: Optional[NDArray[np.floating]] = None,
) -> Tuple[NDArray[np.integer], NDArray[np.floating]]:
    """
    Merge multiple sparse Morgan (count/TF-IDF) fingerprints into a single
    weighted-average fingerprint.

    Parameters
    ----------
    fingerprints :
        Sequence of (bits, values) pairs.
        - bits: 1D integer array of bit indices (non-zero entries)
        - values: 1D float array of TF-IDF (or other) weights,
          same length as `bits`.
    weights :
        Optional 1D array-like of length len(fingerprints) with one weight
        per fingerprint. Each fingerprint's values are scaled by its weight,
        then the merged fingerprint is normalized by the sum of all weights.
        If None, all fingerprints are weighted equally (weight = 1.0).

    Returns
    -------
    merged_bits, merged_values :
        - merged_bits: 1D integer array of unique bit indices
        - merged_values: 1D float32 array of weighted-average values per bit
          (sum over all weighted fingerprints, divided by sum(weights)).

    Raises
    ------
    ValueError
        If `weights` has the wrong length or a non-positive sum, or if any
        (bits, values) pair has mismatched lengths.
    """
    n_fps = len(fingerprints)
    if n_fps == 0:
        # Empty sparse fingerprint; float32 to match the non-empty path.
        return np.array([], dtype=np.int64), np.array([], dtype=np.float32)

    if weights is not None:
        w = np.asarray(weights, dtype=np.float64).ravel()
        if w.shape[0] != n_fps:
            raise ValueError(
                f"weights must have length {n_fps}, got {w.shape[0]}"
            )
        total_weight = float(w.sum())
        if total_weight <= 0.0:
            raise ValueError("Sum of weights must be positive.")
    else:
        # Equal weighting: skip per-fingerprint scaling entirely.
        w = None
        total_weight = float(n_fps)

    # Concatenate all indices and (weighted) values.
    bits_list = []
    vals_list = []
    for i, (bits, vals) in enumerate(fingerprints):
        bits = np.asarray(bits)
        vals = np.asarray(vals, dtype=np.float64)

        if bits.shape[0] != vals.shape[0]:
            raise ValueError(
                f"Fingerprint {i}: bits and values must have same length, "
                f"got {bits.shape[0]} and {vals.shape[0]}"
            )

        if w is not None:
            vals = vals * w[i]

        bits_list.append(bits)
        vals_list.append(vals)

    # Note: bits_list is never empty here (n_fps == 0 already returned),
    # so no extra emptiness check is needed before concatenating.
    all_bits = np.concatenate(bits_list)
    all_vals = np.concatenate(vals_list)

    if all_bits.size == 0:
        # Every fingerprint was empty; return a consistent empty result.
        return np.array([], dtype=np.int64), np.array([], dtype=np.float32)

    # Group by bit index and sum the weighted values per unique bit.
    unique_bits, inverse = np.unique(all_bits, return_inverse=True)
    summed_vals = np.bincount(inverse, weights=all_vals)

    # Weighted average: divide by the total weight once at the end.
    avg_vals = summed_vals / total_weight

    # Keep dtypes reasonably tight: bit dtype follows the inputs,
    # values are float32.
    merged_bits = unique_bits.astype(all_bits.dtype, copy=False)
    merged_vals = avg_vals.astype(np.float32, copy=False)

    return merged_bits, merged_vals
358+
258359
### ------------------------
259360
### Bit Scaling and Weighing
260361
### ------------------------

ms2query/database/ann_vector_index.py

Lines changed: 159 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -323,28 +323,63 @@ def _create_hnsw_index(
323323

324324
def query(
    self,
    vectors: np.ndarray,
    k: int = 10,
    ef: Optional[int] = None,
    num_threads: int = 0,
) -> List[Tuple[str, float]] | List[List[Tuple[str, float]]]:
    """
    Query the index for the k nearest neighbors.

    Parameters
    ----------
    vectors : np.ndarray
        Either a single vector of shape (dim,) or a batch of shape (N, dim).
    k : int
        Number of neighbors.
    ef : Optional[int]
        Optional per-query ef parameter for HNSW.
    num_threads : int
        Number of threads to use inside nmslib (0 = library default).

    Returns
    -------
    Union[List[Tuple[str, float]], List[List[Tuple[str, float]]]]
        For a single input vector, a list of (spec_id, similarity) tuples;
        for a batch, one such list per query.
    """
    if self._index is None:
        raise RuntimeError("Index not built or loaded.")

    arr = np.asarray(vectors, dtype=np.float32)

    if arr.ndim == 1:
        # Promote a single (dim,) vector to a one-row batch.
        if arr.size != self.dim:
            raise ValueError(f"Query must have dim={self.dim}")
        arr = arr.reshape(1, -1)
        is_single = True
    elif arr.ndim == 2:
        if arr.shape[1] != self.dim:
            raise ValueError(f"Expected shape (N, {self.dim}), got {arr.shape}")
        is_single = False
    else:
        raise ValueError("vectors must be 1D or 2D array.")

    if ef is not None:
        self._index.setQueryTimeParams({"ef": ef})

    results = self._index.knnQueryBatch(arr, k=k, num_threads=num_threads)

    per_query: List[List[Tuple[str, float]]] = []
    for neighbor_ids, distances in results:
        neighbor_ids = np.asarray(neighbor_ids, dtype=np.int64)
        distances = np.asarray(distances, dtype=np.float32)
        # nmslib reports cosine distance; convert to similarity.
        similarities = 1.0 - distances
        per_query.append(
            [
                (str(self._ids[nid]), float(sim))
                for nid, sim in zip(neighbor_ids, similarities)
            ]
        )

    return per_query[0] if is_single else per_query
348383

349384
def save_index(self, path_prefix: str) -> None:
350385
if self._index is None:
@@ -450,51 +485,141 @@ def build_index(
450485

451486
def query(
    self,
    query_fp: (
        Tuple[np.ndarray, np.ndarray]
        | sp.csr_matrix
        | Sequence[Tuple[np.ndarray, np.ndarray]]
    ),
    k: int = 10,
    *,
    ef: Optional[int] = None,
    re_rank: bool = True,
    candidate_multiplier: int = 5,
    num_threads: int = 0,
) -> List[Tuple[int, float]] | List[List[Tuple[int, float]]]:
    """
    Query for k nearest neighbors.

    Parameters
    ----------
    query_fp :
        - Single query:
          * (indices, values) tuple
          * single-row CSR of shape (1, dim)
        - Batched queries:
          * CSR of shape (N, dim)
          * Sequence of (indices, values) tuples
    k : int
        Number of results per query.
    ef : Optional[int]
        Optional per-query ef parameter for HNSW.
    re_rank : bool
        Use exact Tanimoto re-ranking.
    candidate_multiplier : int
        Fetch k * multiplier candidates for re-ranking.
    num_threads : int
        Number of threads to use inside nmslib (0 = library default).

    Returns
    -------
    Union[List[Tuple[int, float]], List[List[Tuple[int, float]]]]
        - For a single query, returns a list of (comp_id, similarity).
        - For multiple queries, returns a list (per query) of such lists.
    """
    if self._index is None:
        raise RuntimeError("Index not built or loaded.")

    # -------------------------
    # Normalize input to CSR
    # -------------------------
    single = False

    if isinstance(query_fp, sp.csr_matrix):
        Q = query_fp.astype(np.float32, copy=False)
        if Q.shape[1] != self.dim:
            raise ValueError(f"CSR query must have shape (N, {self.dim})")
        single = Q.shape[0] == 1

    elif isinstance(query_fp, tuple):
        # Single (indices, values) pair.
        Q = csr_row_from_tuple(query_fp, dim=self.dim)
        single = True

    else:
        # Assume sequence of (indices, values) tuples -> batched queries.
        Q = tuples_to_csr(query_fp, dim=self.dim)
        single = Q.shape[0] == 1

    if (Q.data < 0).any():
        raise ValueError("Query must be non-negative for Tanimoto.")

    # Handle completely empty queries quickly.
    row_nnz = Q.indptr[1:] - Q.indptr[:-1]
    if row_nnz.sum() == 0:
        if single:
            return []
        return [[] for _ in range(Q.shape[0])]

    if ef is not None:
        self._index.setQueryTimeParams({"ef": ef})

    # Over-fetch candidates so re-ranking has something to choose from.
    fetch = max(k, k * candidate_multiplier)

    # -------------------------
    # ANN search for all queries
    # -------------------------
    batch_results = self._index.knnQueryBatch(Q, k=fetch, num_threads=num_threads)

    # -------------------------
    # No re-ranking: cosine sims only
    # -------------------------
    if not re_rank or self._csr is None or self._l1 is None:
        all_out: List[List[Tuple[int, float]]] = []

        for qi, (idxs, dists) in enumerate(batch_results):
            # Defensive: an all-zero row has no meaningful neighbors.
            if row_nnz[qi] == 0:
                all_out.append([])
                continue

            idxs = np.asarray(idxs, dtype=np.int64)
            dists = np.asarray(dists, dtype=np.float32)

            sims = 1.0 - dists  # cosine distance -> similarity
            # Cast comp_id to a plain int so the declared return type
            # (List[Tuple[int, float]]) holds; numpy integers would leak
            # through otherwise.
            out = [
                (int(self._comp_ids[i]), float(s))
                for i, s in zip(idxs[:k], sims[:k])
            ]
            all_out.append(out)

        return all_out[0] if single else all_out

    # -------------------------
    # Exact Tanimoto re-ranking
    # -------------------------
    all_out: List[List[Tuple[int, float]]] = []

    for qi, (idxs, dists) in enumerate(batch_results):
        if row_nnz[qi] == 0:
            all_out.append([])
            continue

        idxs = np.asarray(idxs, dtype=np.int64)

        q_row = Q[qi]
        Y = self._csr[idxs]
        tan = tanimoto_l1_query_vs_block(
            q_row,
            Y,
            sum1=float(q_row.sum()),
            sumsY=self._l1[idxs],
        )

        order = np.argsort(-tan)[:k]
        # int(...) cast for the same reason as above.
        out = [
            (int(self._comp_ids[idxs[i]]), float(tan[i]))
            for i in order
        ]
        all_out.append(out)

    return all_out[0] if single else all_out
498623

499624
def _normalize_query(self, query_fp) -> sp.csr_matrix:
500625
"""Convert query to single-row CSR and validate."""
@@ -602,10 +727,11 @@ def save_index(self, path_prefix: str) -> None:
602727
if self._index is None:
603728
raise RuntimeError("Index not built.")
604729

605-
self._index.saveIndex(f"{path_prefix}.nmslib")
730+
# Also save data so that load_index(..., load_data=True) works
731+
self._index.saveIndex(f"{path_prefix}.nmslib", save_data=True)
606732
np.save(f"{path_prefix}.ids.npy", self._comp_ids)
607733

608-
meta = {**self._meta, "dim": self.dim, "space": self.space}
734+
meta = {**self._meta, "dim": int(self.dim), "space": str(self.space)}
609735
with open(f"{path_prefix}.meta.json", "w") as f:
610736
json.dump(meta, f)
611737

0 commit comments

Comments (0)