Skip to content

Commit 39ddce9

Browse files
authored
fix: saving model outputs (#66)
* fix: saving jax arrays * fix: loading score * refactor: all file openings to follow same pattern * fix: saving jax arrays
1 parent f24c5c2 commit 39ddce9

File tree

3 files changed

+16
-9
lines changed

3 files changed

+16
-9
lines changed

src/mlipaudit/io.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,9 @@ def write_benchmark_result_to_disk(
4949
_output_dir.mkdir(exist_ok=True, parents=True)
5050
(_output_dir / benchmark_name).mkdir(exist_ok=True)
5151

52-
with (_output_dir / benchmark_name / RESULT_FILENAME).open("w") as json_file:
52+
with open(
53+
_output_dir / benchmark_name / RESULT_FILENAME, mode="w", encoding="utf-8"
54+
) as json_file:
5355
json_as_str = json.loads(result.model_dump_json()) # type: ignore
5456
json.dump(json_as_str, json_file, indent=2)
5557

@@ -71,7 +73,9 @@ def load_benchmark_result_from_disk(
7173
_results_dir = Path(results_dir)
7274
benchmark_subdir = _results_dir / benchmark_class.name
7375

74-
with (benchmark_subdir / RESULT_FILENAME).open("r", encoding="utf-8") as json_file:
76+
with open(
77+
benchmark_subdir / RESULT_FILENAME, mode="r", encoding="utf-8"
78+
) as json_file:
7579
json_data = json.load(json_file)
7680

7781
return benchmark_class.result_class(**json_data) # type: ignore
@@ -142,8 +146,8 @@ def write_scores_to_disk(
142146
"""
143147
_output_dir = Path(output_dir)
144148
_output_dir.mkdir(exist_ok=True, parents=True)
145-
with (_output_dir / SCORE_FILENAME).open("w") as json_file:
146-
json.dump(scores, json_file, indent=2)
149+
with open(_output_dir / SCORE_FILENAME, "w", encoding="utf-8") as f:
150+
json.dump(scores, f, indent=2)
147151

148152

149153
def load_score_from_disk(
@@ -160,8 +164,9 @@ def load_score_from_disk(
160164
A dictionary of scores where the keys are the
161165
benchmark names.
162166
"""
163-
with (Path(output_dir) / SCORE_FILENAME).open("w") as json_file:
164-
scores = json.load(json_file)
167+
with open(Path(output_dir) / SCORE_FILENAME, mode="r", encoding="utf-8") as f:
168+
scores = json.load(f)
169+
165170
return scores
166171

167172

@@ -213,7 +218,7 @@ def write_model_output_to_disk(
213218
json_path = Path(tmpdir) / MODEL_OUTPUT_JSON_FILENAME
214219
arrays_path = Path(tmpdir) / MODEL_OUTPUT_ARRAYS_FILENAME
215220

216-
with json_path.open("w") as json_file:
221+
with open(json_path, "w", encoding="utf-8") as json_file:
217222
json.dump(data, json_file)
218223

219224
np.savez(arrays_path, **arrays)

src/mlipaudit/io_helpers.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from dataclasses import fields, is_dataclass
1616
from typing import Any, ClassVar, Protocol, Type, TypeVar, get_args
1717

18+
import jax.numpy as jnp
1819
import numpy as np
1920
import pydantic
2021

@@ -63,7 +64,7 @@ def recurse(value):
6364
return {k: recurse(v) for k, v in value.items()}
6465
elif isinstance(value, (list, tuple)):
6566
return [recurse(v) for v in value]
66-
elif isinstance(value, np.ndarray):
67+
elif isinstance(value, (np.ndarray, jnp.ndarray)):
6768
key = f"np_{counter[0]}"
6869
arrays[key] = value
6970
counter[0] += 1

tests/test_io.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from dataclasses import fields
1818
from pathlib import Path
1919

20+
import jax.numpy as jnp
2021
import numpy as np
2122
from mlip.simulation import SimulationState
2223

@@ -78,7 +79,7 @@ def test_model_outputs_io_works(
7879
# First, set up two different simulation states
7980
dummy_sim_state_1 = SimulationState(
8081
atomic_numbers=np.array([1, 8, 6, 1]),
81-
positions=np.ones((7, 4, 3)),
82+
positions=jnp.ones((7, 4, 3)),
8283
forces=np.random.rand(7, 4, 3),
8384
velocities=np.zeros((7, 4, 3)),
8485
temperature=np.full((7,), 1.23),

0 commit comments

Comments
 (0)