Skip to content

Commit 73455a1

Browse files
committed
Added save and load options for AnnotatedSpectrumSet
1 parent 83c6c7d commit 73455a1

File tree

2 files changed

+41
-0
lines changed

2 files changed

+41
-0
lines changed

ms2query/benchmarking/AnnotatedSpectrumSet.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1+
import os
12
from collections import defaultdict
23
from typing import Iterable, List, Mapping, Optional, Sequence
34
from matchms import Spectrum
5+
from matchms.exporting import save_spectra
6+
from matchms.importing import load_spectra
47
from ms2deepscore.models import SiameseSpectralModel
58
from tqdm import tqdm
69
from ms2query.benchmarking.Embeddings import Embeddings
@@ -144,3 +147,23 @@ def __str__(self):
144147
with_embeddings = "with embeddings"
145148

146149
return f"{len(self)} spectra for {len(self.inchikeys)} inchikeys {with_embeddings}"
150+
151+
def save(self, save_file: str) -> None:
152+
"""Save spectra to the specified path"""
153+
save_spectra(list(self._spectra), save_file)
154+
155+
if self._embeddings is not None:
156+
embedding_save_name = os.path.splitext(save_file)[0] + "_embeddings.npz"
157+
print(f"Saving embeddings at {embedding_save_name}")
158+
self._embeddings.save(embedding_save_name)
159+
160+
@classmethod
161+
def load(cls, spectrum_file: str) -> "AnnotatedSpectrumSet":
162+
"""Load mass spectra into a AnnotatedSpectrumSet, if embeddings are available they are loaded too"""
163+
spectra = list(load_spectra(spectrum_file))
164+
165+
embedding_file_name = os.path.splitext(spectrum_file)[0] + "_embeddings.npz"
166+
instance = cls.create_spectrum_set(spectra)
167+
if os.path.exists(embedding_file_name):
168+
instance.embeddings = Embeddings.load(embedding_file_name)
169+
return instance

tests/test_benchmarking/test_SpectrumDataSet.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import os
2+
13
import pytest
24
from ms2query.benchmarking.AnnotatedSpectrumSet import (
35
AnnotatedSpectrumSet,
@@ -76,3 +78,19 @@ def test_add_embeddings():
7678
with pytest.raises(ValueError):
7779
# The spectra don't match so it should raise a valueerror
7880
spectrum_set.embeddings = subset_of_spectra.embeddings
81+
82+
83+
def test_save_and_load(tmp_path):
84+
test_spectra = create_test_spectra(nr_of_inchikeys=3, number_of_spectra_per_inchikey=3)
85+
spectrum_set = AnnotatedSpectrumSet.create_spectrum_set(test_spectra)
86+
file_name = os.path.join(tmp_path, "spectra.mgf")
87+
spectrum_set.save(file_name)
88+
loaded_spectrum_set = spectrum_set.load(file_name)
89+
assert spectrum_set == loaded_spectrum_set
90+
91+
model = ms2deepscore_model()
92+
spectrum_set.add_embeddings(model)
93+
file_name = os.path.join(tmp_path, "spectra_2.mgf")
94+
spectrum_set.save(file_name)
95+
loaded_spectrum_set = spectrum_set.load(file_name)
96+
assert spectrum_set == loaded_spectrum_set

0 commit comments

Comments
 (0)