@@ -30,16 +30,16 @@ def create_from_spectra(cls, spectra: Sequence[Spectrum], model: SiameseSpectral
3030 embeddings : np .ndarray = compute_embedding_array (model , spectra ) # type: ignore
3131 return cls (embeddings , index_to_spectrum_hash , model_settings )
3232
33- @classmethod
34- def combine_embeddings (cls , embeddings_1 : "Embeddings" , embeddings_2 : "Embeddings" ) -> "Embeddings" :
35- if embeddings_1 .model_settings != embeddings_2 .model_settings :
33+ def __add__ (self , other : "Embeddings" ) -> "Embeddings" :
34+ if not isinstance (other , Embeddings ):
35+ return NotImplemented
36+ if self .model_settings != other .model_settings :
3637 raise ValueError ("Model settings of merged embeddings do not match" )
37- if not set (embeddings_1 .index_to_spectrum_hash ).isdisjoint (embeddings_2 .index_to_spectrum_hash ):
38- # todo allow this to happen, but remove repeating ones and check that they are the same.
38+ if not set (self .index_to_spectrum_hash ).isdisjoint (other .index_to_spectrum_hash ):
3939 raise ValueError ("There are repeated spectra in the embeddings that are added together" )
40- combined_embeddings = np .vstack ([embeddings_1 . embeddings , embeddings_2 . embeddings ])
41- index_to_spectrum_hash = embeddings_1 .index_to_spectrum_hash + embeddings_2 .index_to_spectrum_hash
42- return cls (combined_embeddings , index_to_spectrum_hash , embeddings_1 . model_settings )
40+ combined_embeddings = np .vstack ([self . _embeddings , other . _embeddings ])
41+ index_to_spectrum_hash = self .index_to_spectrum_hash + other .index_to_spectrum_hash
42+ return Embeddings (combined_embeddings , index_to_spectrum_hash , self . _model_settings )
4343
4444 def subset_embeddings (self , spectra ):
4545 spectrum_hashes = tuple (spectrum .__hash__ () for spectrum in spectra )
0 commit comments