Skip to content

Commit f080a5e

Browse files
committed
Added an embedding setter with checks to AnnotatedSpectrumSet
1 parent a11bd46 commit f080a5e

File tree

2 files changed

+28
-1
lines changed

2 files changed

+28
-1
lines changed

ms2query/benchmarking/AnnotatedSpectrumSet.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def __init__(
1919
self.spectrum_indices_per_inchikey: dict[str, tuple[int, ...]] = {
2020
key: tuple(values) for key, values in spectrum_indices_per_inchikey.items()
2121
}
22-
self._embeddings = embeddings
22+
self.embeddings = embeddings
2323

2424
@classmethod
2525
def create_spectrum_set(cls, spectra: Sequence[Spectrum]) -> "AnnotatedSpectrumSet":
@@ -99,6 +99,17 @@ def embeddings(self) -> Embeddings:
9999
raise ValueError("First run the 'add_embeddings' method")
100100
return self._embeddings
101101

102+
@embeddings.setter
103+
def embeddings(self, embeddings: Optional[Embeddings]):
104+
if embeddings is None:
105+
self._embeddings = embeddings
106+
return
107+
if not embeddings.index_to_spectrum_hash == tuple(spectrum.__hash__() for spectrum in self.spectra):
108+
raise ValueError(
109+
"The embeddings spectrum hashes don't match the spectrum hashes, make sure you use matching embeddings"
110+
)
111+
self._embeddings = embeddings
112+
102113
@property
103114
def inchikeys(self):
104115
return tuple(self.spectrum_indices_per_inchikey.keys())

tests/test_benchmarking/test_SpectrumDataSet.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import pytest
12
from ms2query.benchmarking.AnnotatedSpectrumSet import (
23
AnnotatedSpectrumSet,
34
)
@@ -60,3 +61,18 @@ def test_subset_on_metadata():
6061
spectrum_set.add_embeddings(model)
6162

6263
assert correct_subsetted_set == spectrum_set.subset_spectra_on_metadata("ionmode", set(["positive"]))
64+
65+
66+
def test_add_embeddings():
67+
test_spectra = create_test_spectra(nr_of_inchikeys=3, number_of_spectra_per_inchikey=3)
68+
69+
spectrum_set = AnnotatedSpectrumSet.create_spectrum_set(test_spectra)
70+
71+
subset_of_spectra = AnnotatedSpectrumSet.create_spectrum_set(test_spectra[:5])
72+
spectrum_set = AnnotatedSpectrumSet.create_spectrum_set(test_spectra)
73+
# with added embededings
74+
model = ms2deepscore_model()
75+
subset_of_spectra.add_embeddings(model)
76+
with pytest.raises(ValueError):
77+
# The spectra don't match so it should raise a valueerror
78+
spectrum_set.embeddings = subset_of_spectra.embeddings

0 commit comments

Comments
 (0)