Skip to content

Commit a11bd46

Browse files
committed
Replace `combine_embeddings` with `__add__`
1 parent 41e4c4b commit a11bd46

File tree

2 files changed

+11
-11
lines changed

2 files changed

+11
-11
lines changed

ms2query/benchmarking/Embeddings.py

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -30,16 +30,16 @@ def create_from_spectra(cls, spectra: Sequence[Spectrum], model: SiameseSpectral
3030
embeddings: np.ndarray = compute_embedding_array(model, spectra) # type: ignore
3131
return cls(embeddings, index_to_spectrum_hash, model_settings)
3232

33-
@classmethod
34-
def combine_embeddings(cls, embeddings_1: "Embeddings", embeddings_2: "Embeddings") -> "Embeddings":
35-
if embeddings_1.model_settings != embeddings_2.model_settings:
33+
def __add__(self, other: "Embeddings") -> "Embeddings":
34+
if not isinstance(other, Embeddings):
35+
return NotImplemented
36+
if self.model_settings != other.model_settings:
3637
raise ValueError("Model settings of merged embeddings do not match")
37-
if not set(embeddings_1.index_to_spectrum_hash).isdisjoint(embeddings_2.index_to_spectrum_hash):
38-
# todo allow this to happen, but remove repeating ones and check that they are the same.
38+
if not set(self.index_to_spectrum_hash).isdisjoint(other.index_to_spectrum_hash):
3939
raise ValueError("There are repeated spectra in the embeddings that are added together")
40-
combined_embeddings = np.vstack([embeddings_1.embeddings, embeddings_2.embeddings])
41-
index_to_spectrum_hash = embeddings_1.index_to_spectrum_hash + embeddings_2.index_to_spectrum_hash
42-
return cls(combined_embeddings, index_to_spectrum_hash, embeddings_1.model_settings)
40+
combined_embeddings = np.vstack([self._embeddings, other._embeddings])
41+
index_to_spectrum_hash = self.index_to_spectrum_hash + other.index_to_spectrum_hash
42+
return Embeddings(combined_embeddings, index_to_spectrum_hash, self._model_settings)
4343

4444
def subset_embeddings(self, spectra):
4545
spectrum_hashes = tuple(spectrum.__hash__() for spectrum in spectra)

tests/test_benchmarking/test_embeddings.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -19,16 +19,16 @@ def test_subset_embeddings():
1919
embeddings.subset_embeddings(test_spectra)
2020

2121

22-
def test_combine_embeddings():
22+
def test_add_embeddings():
2323
test_spectra = create_test_spectra()
2424
model = ms2deepscore_model()
2525
embeddings_1 = Embeddings.create_from_spectra(test_spectra[:4], model)
2626
embeddings_2 = Embeddings.create_from_spectra(test_spectra[4:], model)
2727
correct_combined_embeddings = Embeddings.create_from_spectra(test_spectra, model)
28-
combined_embeddings = Embeddings.combine_embeddings(embeddings_1, embeddings_2)
28+
combined_embeddings = embeddings_1 + embeddings_2
2929
assert combined_embeddings == correct_combined_embeddings
3030
with pytest.raises(ValueError):
31-
Embeddings.combine_embeddings(embeddings_1, embeddings_1)
31+
_ = embeddings_1 + embeddings_1
3232

3333

3434
def test_calculate_ms2deepscore_df():

0 commit comments

Comments
 (0)