1+ import json
2+ from pathlib import Path
13from typing import Sequence
24import numpy as np
35import pandas as pd
@@ -17,7 +19,7 @@ def __init__(self, embeddings: np.ndarray, spectrum_hashes: tuple, model_setting
1719 self ._spectrum_hash_to_index = {
1820 spectrum_hash : index for index , spectrum_hash in enumerate (self .index_to_spectrum_hash )
1921 }
20- self ._model_settings = model_settings
22+ self .model_settings = model_settings
2123 self ._embeddings = embeddings
2224
2325 @classmethod
@@ -39,7 +41,7 @@ def __add__(self, other: "Embeddings") -> "Embeddings":
3941 raise ValueError ("There are repeated spectra in the embeddings that are added together" )
4042 combined_embeddings = np .vstack ([self ._embeddings , other ._embeddings ])
4143 index_to_spectrum_hash = self .index_to_spectrum_hash + other .index_to_spectrum_hash
42- return Embeddings (combined_embeddings , index_to_spectrum_hash , self ._model_settings )
44+ return Embeddings (combined_embeddings , index_to_spectrum_hash , self .model_settings )
4345
4446 def subset_embeddings (self , spectra ):
4547 spectrum_hashes = tuple (spectrum .__hash__ () for spectrum in spectra )
@@ -58,11 +60,15 @@ def embeddings(self):
5860 def model_settings (self ):
5961 return self ._model_settings .copy ()
6062
@model_settings.setter
def model_settings(self, model_settings):
    """Store the model settings after passing them through ``_to_json_serializable``."""
    converted_settings = _to_json_serializable(model_settings)
    self._model_settings: dict = converted_settings
66+
def copy(self) -> "Embeddings":
    """Return a new ``Embeddings`` built from copies of this instance's data.

    The embedding array is copied with ``ndarray.copy()``, the spectrum hashes
    are re-wrapped in a tuple and the model settings in a new dict.
    """
    copied_array = self._embeddings.copy()
    copied_hashes = tuple(self.index_to_spectrum_hash)
    copied_settings = dict(self.model_settings)
    return Embeddings(
        embeddings=copied_array,
        spectrum_hashes=copied_hashes,
        model_settings=copied_settings,
    )
6773
6874 def __eq__ (self , other ) -> bool :
@@ -76,10 +82,59 @@ def __eq__(self, other) -> bool:
7682 return False
7783 return np .array_equal (self .embeddings , other .embeddings )
7884
85+ def save (self , path : str | Path ) -> None :
86+ """Save embeddings to a .npz file with metadata stored alongside.
87+
88+ Args:
89+ path: File path. A '.npz' extension will be added if not present.
90+ """
91+ path = Path (path ).with_suffix (".npz" )
92+ metadata = {
93+ "model_settings" : self .model_settings ,
94+ "index_to_spectrum_hash" : list (self .index_to_spectrum_hash ),
95+ }
96+ np .savez_compressed (
97+ path ,
98+ embeddings = self ._embeddings ,
99+ metadata = np .array (json .dumps (metadata )),
100+ )
101+
102+ @classmethod
103+ def load (cls , path : str | Path ) -> "Embeddings" :
104+ """Load embeddings from a saved .npz file.
105+
106+ Args:
107+ path: Path to the saved .npz file.
108+ """
109+ path = Path (path ).with_suffix (".npz" )
110+ with np .load (path , allow_pickle = False ) as data :
111+ embeddings = data ["embeddings" ]
112+ metadata = json .loads (data ["metadata" ].item ())
113+ return cls (
114+ embeddings = embeddings ,
115+ spectrum_hashes = tuple (metadata ["index_to_spectrum_hash" ]),
116+ model_settings = metadata ["model_settings" ],
117+ )
118+
79119
def calculate_ms2deepscore_df(query_embeddings: Embeddings, library_embeddings: Embeddings):
    """Return the cosine similarity matrix of two embedding sets as a DataFrame.

    Rows are labelled with the query spectrum hashes and columns with the
    library spectrum hashes.
    """
    score_matrix = cosine_similarity_matrix(query_embeddings.embeddings, library_embeddings.embeddings)
    row_labels = query_embeddings.index_to_spectrum_hash
    column_labels = library_embeddings.index_to_spectrum_hash
    return pd.DataFrame(score_matrix, index=row_labels, columns=column_labels)
126+
127+
def _to_json_serializable(obj):
    """Recursively convert ``obj`` into plain Python types that ``json`` can serialize.

    The conversion is applied up front so that the value looks the same before
    saving and after loading it back from JSON.

    Conversions:
        * dict values are converted recursively (keys are passed through unchanged)
        * list and tuple become list, with items converted recursively
        * numpy arrays become (nested) lists
        * numpy bool, integer and floating scalars become bool, int and float

    Any other object is returned unchanged.
    """
    if isinstance(obj, dict):
        return {key: _to_json_serializable(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [_to_json_serializable(item) for item in obj]
    if isinstance(obj, np.ndarray):
        # Recurse into the result: object arrays may still hold numpy scalars.
        return _to_json_serializable(obj.tolist())
    if isinstance(obj, np.bool_):
        # np.bool_ is not an np.integer and json.dumps rejects it.
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    return obj
0 commit comments