Skip to content

Commit 83c6c7d

Browse files
committed
Added save and load models for embeddings
1 parent f080a5e commit 83c6c7d

File tree

2 files changed

+69
-3
lines changed

2 files changed

+69
-3
lines changed

ms2query/benchmarking/Embeddings.py

Lines changed: 58 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import json
2+
from pathlib import Path
13
from typing import Sequence
24
import numpy as np
35
import pandas as pd
@@ -17,7 +19,7 @@ def __init__(self, embeddings: np.ndarray, spectrum_hashes: tuple, model_setting
1719
self._spectrum_hash_to_index = {
1820
spectrum_hash: index for index, spectrum_hash in enumerate(self.index_to_spectrum_hash)
1921
}
20-
self._model_settings = model_settings
22+
self.model_settings = model_settings
2123
self._embeddings = embeddings
2224

2325
@classmethod
@@ -39,7 +41,7 @@ def __add__(self, other: "Embeddings") -> "Embeddings":
3941
raise ValueError("There are repeated spectra in the embeddings that are added together")
4042
combined_embeddings = np.vstack([self._embeddings, other._embeddings])
4143
index_to_spectrum_hash = self.index_to_spectrum_hash + other.index_to_spectrum_hash
42-
return Embeddings(combined_embeddings, index_to_spectrum_hash, self._model_settings)
44+
return Embeddings(combined_embeddings, index_to_spectrum_hash, self.model_settings)
4345

4446
def subset_embeddings(self, spectra):
4547
spectrum_hashes = tuple(spectrum.__hash__() for spectrum in spectra)
@@ -58,11 +60,15 @@ def embeddings(self):
5860
def model_settings(self):
5961
return self._model_settings.copy()
6062

63+
@model_settings.setter
def model_settings(self, model_settings):
    """Store the model settings after normalising them for JSON.

    The given settings are passed through ``_to_json_serializable`` and the
    result is kept on the private ``_model_settings`` attribute.
    """
    normalised_settings: dict = _to_json_serializable(model_settings)
    self._model_settings = normalised_settings
66+
6167
def copy(self) -> "Embeddings":
    """Return a new ``Embeddings`` built from copies of this instance's data.

    The embedding array is duplicated through its ``copy()`` method, the
    spectrum hashes are rebuilt as a fresh tuple and the model settings are
    handed over as a new dict.
    """
    duplicated_array = self._embeddings.copy()
    hashes = tuple(self.index_to_spectrum_hash)
    settings = dict(self.model_settings)
    return Embeddings(
        embeddings=duplicated_array,
        spectrum_hashes=hashes,
        model_settings=settings,
    )
6773

6874
def __eq__(self, other) -> bool:
@@ -76,10 +82,59 @@ def __eq__(self, other) -> bool:
7682
return False
7783
return np.array_equal(self.embeddings, other.embeddings)
7884

85+
def save(self, path: str | Path) -> None:
86+
"""Save embeddings to a .npz file with metadata stored alongside.
87+
88+
Args:
89+
path: File path. A '.npz' extension will be added if not present.
90+
"""
91+
path = Path(path).with_suffix(".npz")
92+
metadata = {
93+
"model_settings": self.model_settings,
94+
"index_to_spectrum_hash": list(self.index_to_spectrum_hash),
95+
}
96+
np.savez_compressed(
97+
path,
98+
embeddings=self._embeddings,
99+
metadata=np.array(json.dumps(metadata)),
100+
)
101+
102+
@classmethod
103+
def load(cls, path: str | Path) -> "Embeddings":
104+
"""Load embeddings from a saved .npz file.
105+
106+
Args:
107+
path: Path to the saved .npz file.
108+
"""
109+
path = Path(path).with_suffix(".npz")
110+
with np.load(path, allow_pickle=False) as data:
111+
embeddings = data["embeddings"]
112+
metadata = json.loads(data["metadata"].item())
113+
return cls(
114+
embeddings=embeddings,
115+
spectrum_hashes=tuple(metadata["index_to_spectrum_hash"]),
116+
model_settings=metadata["model_settings"],
117+
)
118+
79119

80120
def calculate_ms2deepscore_df(query_embeddings: Embeddings, library_embeddings: Embeddings):
    """Build a score DataFrame labelled by spectrum hash.

    The values come from ``cosine_similarity_matrix`` applied to the query and
    library embedding arrays. Rows are labelled with the query spectrum hashes
    and columns with the library spectrum hashes.
    """
    similarity_scores = cosine_similarity_matrix(query_embeddings.embeddings, library_embeddings.embeddings)
    row_labels = query_embeddings.index_to_spectrum_hash
    column_labels = library_embeddings.index_to_spectrum_hash
    return pd.DataFrame(similarity_scores, index=row_labels, columns=column_labels)
126+
127+
128+
def _to_json_serializable(obj):
    """Recursively convert ``obj`` into plain JSON-serializable Python values.

    Dict values are converted recursively (keys are left untouched), lists and
    tuples become lists, NumPy integer/floating/boolean scalars become
    ``int``/``float``/``bool`` and NumPy arrays become (nested) lists. Any other
    object is returned unchanged.

    Normalising up front means the stored value already has the shape ``json``
    gives back after a dump/load round trip (e.g. tuples are already lists).
    """
    if isinstance(obj, dict):
        return {key: _to_json_serializable(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [_to_json_serializable(item) for item in obj]
    # np.bool_ is not an np.integer subclass and json.dumps rejects it, so it
    # needs its own branch.
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    return obj

tests/test_benchmarking/test_embeddings.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import os
12
import pytest
23
from ms2query.benchmarking.Embeddings import Embeddings, calculate_ms2deepscore_df
34
from tests.helper_functions import create_test_spectra, get_library_and_test_spectra_not_identical, ms2deepscore_model
@@ -34,3 +35,13 @@ def test_add_embeddings():
3435
def test_calculate_ms2deepscore_df():
    """Smoke test: computing the score DataFrame for two embedding sets runs without error."""
    library, queries = get_library_and_test_spectra_not_identical()
    # NOTE(review): the library embeddings are passed first here, while the
    # function's parameters are named (query_embeddings, library_embeddings).
    # No result is asserted, so this does not affect the test - confirm intent.
    calculate_ms2deepscore_df(library.embeddings, queries.embeddings)
38+
39+
40+
def test_save_and_load(tmp_path):
    """Embeddings written with ``save`` must compare equal once read back with ``load``."""
    spectra = create_test_spectra()
    model = ms2deepscore_model()
    original = Embeddings.create_from_spectra(spectra[:4], model)

    archive_path = os.path.join(tmp_path, "embeddings.npz")
    original.save(archive_path)
    restored = Embeddings.load(archive_path)

    assert original == restored

0 commit comments

Comments
 (0)