|
| 1 | +import os |
1 | 2 | from collections import defaultdict |
2 | 3 | from typing import Iterable, List, Mapping, Optional, Sequence |
3 | 4 | from matchms import Spectrum |
| 5 | +from matchms.exporting import save_spectra |
| 6 | +from matchms.importing import load_spectra |
4 | 7 | from ms2deepscore.models import SiameseSpectralModel |
5 | 8 | from tqdm import tqdm |
6 | 9 | from ms2query.benchmarking.Embeddings import Embeddings |
@@ -144,3 +147,23 @@ def __str__(self): |
144 | 147 | with_embeddings = "with embeddings" |
145 | 148 |
|
146 | 149 | return f"{len(self)} spectra for {len(self.inchikeys)} inchikeys {with_embeddings}" |
| 150 | + |
def save(self, save_file: str) -> None:
    """Write the spectra to ``save_file`` and store embeddings alongside it.

    The spectra are passed to ``save_spectra`` as a list. When embeddings
    are set on this object they are saved to a companion file whose name
    is ``save_file`` with its extension replaced by ``_embeddings.npz``.
    """
    spectra_to_write = list(self._spectra)
    save_spectra(spectra_to_write, save_file)

    embeddings = self._embeddings
    if embeddings is None:
        # Nothing beyond the spectra to persist.
        return

    # Companion file shares the spectrum file's base name (extension dropped).
    base_name = os.path.splitext(save_file)[0]
    embedding_save_name = base_name + "_embeddings.npz"
    print(f"Saving embeddings at {embedding_save_name}")
    embeddings.save(embedding_save_name)
| 159 | + |
@classmethod
def load(cls, spectrum_file: str) -> "AnnotatedSpectrumSet":
    """Build an AnnotatedSpectrumSet from a spectrum file.

    Spectra are read with ``load_spectra`` and passed to
    ``create_spectrum_set``. If a companion file named like
    ``spectrum_file`` with its extension replaced by ``_embeddings.npz``
    exists, the embeddings stored there are loaded and attached as well.
    """
    loaded_spectra = list(load_spectra(spectrum_file))

    # Companion file shares the spectrum file's base name (extension dropped).
    base_name = os.path.splitext(spectrum_file)[0]
    embedding_file_name = base_name + "_embeddings.npz"

    spectrum_set = cls.create_spectrum_set(loaded_spectra)
    if not os.path.exists(embedding_file_name):
        # No stored embeddings; return the set with spectra only.
        return spectrum_set

    # NOTE(review): save() reads ``self._embeddings`` while this assigns
    # ``embeddings`` -- confirm the class exposes a matching setter.
    spectrum_set.embeddings = Embeddings.load(embedding_file_name)
    return spectrum_set
0 commit comments