|
| 1 | +import json |
| 2 | +from pathlib import Path |
| 3 | +import numpy as np |
| 4 | +import pandas as pd |
| 5 | +import pytest |
| 6 | +from matchms.importing import load_spectra |
| 7 | +from ms2query import MS2QueryDatabase, MS2QueryLibrary |
| 8 | +from ms2query.library_io import create_new_library, load_created_library |
| 9 | + |
| 10 | + |
| 11 | +TEST_COMP_ID = "ZBSGKPYXQINNGF" # known from your snippet |
| 12 | +EXPECTED_METADATA_SHAPE = (5, 11) |
| 13 | +EXPECTED_METADATA_FIELDS = [ |
| 14 | + "precursor_mz", "ionmode", "smiles", "inchikey", "inchi", "name", |
| 15 | + "charge", "instrument_type", "adduct", "collision_energy", |
| 16 | +] |
| 17 | +EMB_PREFIX = "embedding_index" |
| 18 | +SQLITE_NAME = "ms2query_library.sqlite" |
| 19 | +MANIFEST_NAME = "ms2query_manifest.json" |
| 20 | + |
| 21 | + |
| 22 | +def _data_dir() -> Path: |
| 23 | + return Path(__file__).parent / "test_data" |
| 24 | + |
| 25 | + |
| 26 | +def _paths(): |
| 27 | + data_dir = _data_dir() |
| 28 | + spectra_file = data_dir / "10_spectra.mgf" |
| 29 | + model_path = data_dir / "ms2deepscore_testmodel_v1.pt" |
| 30 | + assert spectra_file.exists(), f"Missing test spectra file: {spectra_file}" |
| 31 | + assert model_path.exists(), f"Missing test model file: {model_path}" |
| 32 | + return spectra_file, model_path |
| 33 | + |
| 34 | + |
| 35 | +def create_lib_from_test_files(tmp_path: Path) -> MS2QueryLibrary: |
| 36 | + """Create a small library from test files and return the loaded MS2QueryLibrary.""" |
| 37 | + spectra_file, model_path = _paths() |
| 38 | + outdir = tmp_path / "ms2query_out" |
| 39 | + outdir.mkdir(parents=True, exist_ok=True) |
| 40 | + |
| 41 | + # Build the library (keep HNSW params small for speed) |
| 42 | + lib_created = create_new_library( |
| 43 | + spectra_files=[str(spectra_file)], |
| 44 | + annotation_files=[], |
| 45 | + output_folder=str(outdir), |
| 46 | + model_path=str(model_path), |
| 47 | + build_embedding_index=True, |
| 48 | + embedding_index_params={"M": 8, "ef_construction": 50, "post_init_ef": 50, "batch_rows": 100_000}, |
| 49 | + compute_embeddings_batch_rows=256, |
| 50 | + ) |
| 51 | + # Sanity |
| 52 | + assert isinstance(lib_created, MS2QueryLibrary) |
| 53 | + assert isinstance(lib_created.db, MS2QueryDatabase) |
| 54 | + |
| 55 | + # Load from disk to mirror real workflow |
| 56 | + lib_loaded = load_created_library(str(outdir)) |
| 57 | + assert isinstance(lib_loaded, MS2QueryLibrary) |
| 58 | + return lib_loaded |
| 59 | + |
| 60 | + |
| 61 | +# -------------------------------------------------------------------- |
| 62 | +# End-to-end smoke test (also re-used by the unit tests below) |
| 63 | +# -------------------------------------------------------------------- |
| 64 | + |
| 65 | +@pytest.mark.filterwarnings("ignore::UserWarning") |
| 66 | +def test_create_and_load_smoke(tmp_path: Path): |
| 67 | + lib = create_lib_from_test_files(tmp_path) |
| 68 | + |
| 69 | + outdir = tmp_path / "ms2query_out" |
| 70 | + db_path = outdir / SQLITE_NAME |
| 71 | + assert db_path.exists(), "MS2Query database file was not created." |
| 72 | + |
| 73 | + # Manifest present & contains basic keys |
| 74 | + manifest_path = outdir / MANIFEST_NAME |
| 75 | + assert manifest_path.exists() |
| 76 | + with open(manifest_path, "r", encoding="utf-8") as f: |
| 77 | + manifest = json.load(f) |
| 78 | + assert manifest.get("sqlite_path") == SQLITE_NAME |
| 79 | + |
| 80 | + # DB content checks |
| 81 | + ms2query_db = lib.db |
| 82 | + meta_df = ms2query_db.metadata_by_comp_id(TEST_COMP_ID) |
| 83 | + assert tuple(meta_df.shape) == EXPECTED_METADATA_SHAPE |
| 84 | + for field in EXPECTED_METADATA_FIELDS: |
| 85 | + assert field in ms2query_db.metadata_fields |
| 86 | + assert field in meta_df.columns |
| 87 | + |
| 88 | + # ANN artifacts (nmslib base + .dat OR legacy .nmslib) |
| 89 | + emb_base = outdir / EMB_PREFIX |
| 90 | + two_file_ok = emb_base.exists() and (emb_base.with_suffix(".dat")).exists() |
| 91 | + legacy_ok = (emb_base.with_suffix(".nmslib")).exists() |
| 92 | + assert two_file_ok or legacy_ok, "Embedding index files missing." |
| 93 | + |
| 94 | + |
| 95 | +# -------------------------------------------------------------------- |
| 96 | +# Unit tests for MS2QueryLibrary methods |
| 97 | +# -------------------------------------------------------------------- |
| 98 | + |
| 99 | +@pytest.mark.filterwarnings("ignore::UserWarning") |
| 100 | +def test_process_and_compute_embeddings(tmp_path: Path): |
| 101 | + lib = create_lib_from_test_files(tmp_path) |
| 102 | + spectra_path, _ = _paths() |
| 103 | + spectra = list(load_spectra(spectra_path)) |
| 104 | + assert len(spectra) > 0 |
| 105 | + |
| 106 | + # process_spectra passthrough for now |
| 107 | + processed = lib.process_spectra(spectra) |
| 108 | + assert isinstance(processed, list) |
| 109 | + assert len(processed) == len(spectra) |
| 110 | + |
| 111 | + # compute_embeddings returns (n, d) float32, d > 0 |
| 112 | + E = lib.compute_embeddings(spectra[:3]) # small batch |
| 113 | + assert isinstance(E, np.ndarray) |
| 114 | + assert E.dtype == np.float32 |
| 115 | + assert E.ndim == 2 and E.shape[0] == 3 and E.shape[1] > 0 |
| 116 | + |
| 117 | + # L2 normalization sanity (norm ~ 1) |
| 118 | + norms = np.linalg.norm(E, axis=1) |
| 119 | + assert np.allclose(norms, 1.0, atol=1e-4) |
| 120 | + |
| 121 | + |
| 122 | +@pytest.mark.filterwarnings("ignore::UserWarning") |
| 123 | +def test_query_embedding_index_returns_dataframe(tmp_path: Path): |
| 124 | + lib = create_lib_from_test_files(tmp_path) |
| 125 | + spectra_path, _ = _paths() |
| 126 | + spectra = list(load_spectra(spectra_path)) |
| 127 | + q = spectra[0] |
| 128 | + |
| 129 | + df = lib.query_embedding_index(q, k=3, ef=40, return_dataframe=True) |
| 130 | + assert isinstance(df, pd.DataFrame) |
| 131 | + assert set(["query_ix", "rank", "spec_id", "score"]).issubset(df.columns) |
| 132 | + assert df.shape[0] >= 1 |
| 133 | + assert isinstance(df["spec_id"].iloc[0], (str, np.str_)) |
| 134 | + |
| 135 | + |
| 136 | +@pytest.mark.filterwarnings("ignore::UserWarning") |
| 137 | +def test_query_spectra_by_spectra_and_compounds(tmp_path: Path): |
| 138 | + lib = create_lib_from_test_files(tmp_path) |
| 139 | + spectra_path, _ = _paths() |
| 140 | + spectra = list(load_spectra(spectra_path)) |
| 141 | + |
| 142 | + # spectra-by-spectra (DataFrame) |
| 143 | + df_s = lib.query_spectra_by_spectra(spectra[:2], k_spectra=5, ef=40) |
| 144 | + assert isinstance(df_s, pd.DataFrame) |
| 145 | + assert set(["query_ix", "rank", "spec_id", "score"]).issubset(df_s.columns) |
| 146 | + assert df_s["query_ix"].nunique() == 2 |
| 147 | + |
| 148 | + # compounds-by-spectra (top-k compounds per query) |
| 149 | + df_c = lib.query_compounds_by_spectra(spectra[:3], k_spectra=20, k_compounds=5, ef=40) |
| 150 | + # Expect columns from metadata + query_ix/rank/score present after merge |
| 151 | + required_cols = set(["query_ix", "spec_id", "score", "inchikey"]).union(EXPECTED_METADATA_FIELDS) |
| 152 | + assert required_cols.issubset(df_c.columns) |
| 153 | + # per query, at most k_compounds rows |
| 154 | + assert (df_c.groupby("query_ix").size() <= 5).all() |
| 155 | + |
| 156 | + |
| 157 | +@pytest.mark.filterwarnings("ignore::UserWarning") |
| 158 | +def test_query_by_spec_ids_uses_db_embeddings(tmp_path: Path): |
| 159 | + lib = create_lib_from_test_files(tmp_path) |
| 160 | + # grab a few ids from the DB |
| 161 | + spec_ids = lib.db.ref_sdb.ids()[:2] |
| 162 | + assert len(spec_ids) >= 1 |
| 163 | + |
| 164 | + df = lib.query_by_spec_ids(spec_ids, k=4, ef=40, return_dataframe=True) |
| 165 | + assert isinstance(df, pd.DataFrame) |
| 166 | + assert set(["query_ix", "rank", "spec_id", "score"]).issubset(df.columns) |
| 167 | + # same number of query_ix values as requested ids |
| 168 | + assert df["query_ix"].nunique() == len(spec_ids) |
| 169 | + |
| 170 | + |
| 171 | +@pytest.mark.filterwarnings("ignore::UserWarning") |
| 172 | +def test_error_paths_missing_model_or_index(tmp_path: Path): |
| 173 | + # Fresh DB only (no index attached) |
| 174 | + lib = create_lib_from_test_files(tmp_path) |
| 175 | + |
| 176 | + # Remove index to test guard rails |
| 177 | + lib.embedding_index = None |
| 178 | + spectra_path, _ = _paths() |
| 179 | + spectra = list(load_spectra(spectra_path)) |
| 180 | + with pytest.raises(RuntimeError, match="EmbeddingIndex is not set"): |
| 181 | + lib.query_embedding_index(spectra[0], k=3) |
| 182 | + |
| 183 | + # New instance without model_path → compute_embeddings should raise when invoked |
| 184 | + lib2 = MS2QueryLibrary(db=lib.db, embedding_index=None, model_path=None) |
| 185 | + with pytest.raises(RuntimeError, match="model_path is not set"): |
| 186 | + lib2.compute_embeddings(spectra[:1]) |
0 commit comments