Skip to content

Commit b22444f

Browse files
authored
fix: save_metadata to support both multiple datasets with many arrays and the old setup with no arrays (#239)
## Description - PR to fix #238. Adds tests that write metadata with no supporting arrays, and with an arrays dict containing multiple entries (simulating a multiple-dataset setup). ## What problem does this change solve? Previous changes broke support for old setups in which supporting arrays were not defined. ## What issue or task does this change relate to? [#238](#238) ## Additional notes <!-- Include any additional information, caveats, or considerations that the reviewer should be aware of. --> ***As a contributor to the Anemoi framework, please ensure that your changes include unit tests, updates to any affected dependencies and documentation, and have been tested in a parallel setting (i.e., with multiple GPUs). As a reviewer, you are also responsible for verifying these aspects and requesting changes if they are not adequately addressed. For guidelines about those, please refer to https://anemoi.readthedocs.io/en/latest/*** By opening this pull request, I affirm that all authors agree to the [Contributor License Agreement](https://github.com/ecmwf/codex/blob/main/Legal/contributor_license_agreement.md).
1 parent 2a5cdfc commit b22444f

File tree

2 files changed

+55
-11
lines changed

2 files changed

+55
-11
lines changed

src/anemoi/utils/checkpoints.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -141,10 +141,13 @@ def load_supporting_arrays(zipf: zipfile.ZipFile, entries: dict) -> dict:
141141

142142
supporting_arrays = {}
143143
for key, entry in entries.items():
144-
supporting_arrays[key] = np.frombuffer(
145-
zipf.read(entry["path"]),
146-
dtype=entry["dtype"],
147-
).reshape(entry["shape"])
144+
if isinstance(entry, dict) and not set(entry.keys()) == set(["path", "shape", "dtype"]):
145+
supporting_arrays[key] = load_supporting_arrays(zipf, entry)
146+
else:
147+
supporting_arrays[key] = np.frombuffer(
148+
zipf.read(entry["path"]),
149+
dtype=entry["dtype"],
150+
).reshape(entry["shape"])
148151
return supporting_arrays
149152

150153

@@ -168,11 +171,14 @@ def _get_supporting_arrays_paths(directory: str, folder: str, supporting_arrays:
168171

169172
def _write_array_to_bytes(array: dict | np.ndarray, name: str, entry: dict, zipf: zipfile.ZipFile) -> None:
170173
"""Write a supporting array to bytes in a zip file."""
174+
if array is None:
175+
return
176+
171177
if isinstance(array, dict):
172178
for sub_name, sub_array in array.items():
173-
_write_array_to_bytes(sub_array, sub_name, entry[sub_name], zipf)
174-
return None
175-
179+
sub_entry = entry.get(sub_name, {})
180+
_write_array_to_bytes(sub_array, sub_name, sub_entry, zipf)
181+
return
176182
LOG.info(
177183
"Saving supporting array `%s` to %s (shape=%s, dtype=%s)",
178184
name,

tests/test_checkpoints.py

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -187,21 +187,32 @@ def test_remove_metadata_integration(self, sample_checkpoint):
187187

188188
assert not has_metadata(sample_checkpoint, name="metadata.json")
189189

190-
def test_metadata_with_arrays_roundtrip(self, sample_checkpoint):
190+
@pytest.mark.parametrize(
191+
"arrays",
192+
[
193+
{"test_array": np.array([1, 2, 3, 4, 5])},
194+
{
195+
"test_array": {"test_array": np.array([10, 20, 30])},
196+
},
197+
],
198+
)
199+
def test_metadata_with_arrays_roundtrip(self, sample_checkpoint, arrays):
191200
"""Test complete roundtrip with supporting arrays."""
192201
# First remove existing metadata
193202
remove_metadata(sample_checkpoint, name="metadata.json")
194203

195204
# Add metadata with arrays
196205
metadata = {"version": "1.0", "test": True}
197-
arrays = {"test_array": np.array([1, 2, 3, 4, 5])}
198206

199207
save_metadata(sample_checkpoint, metadata, supporting_arrays=arrays, name="metadata.json")
200208

201209
# Load and verify
202210
loaded_metadata, loaded_arrays = load_metadata(sample_checkpoint, supporting_arrays=True, name="metadata.json")
203211
assert loaded_metadata["test"] is True
204-
np.testing.assert_array_equal(loaded_arrays["test_array"], arrays["test_array"])
212+
if isinstance(loaded_arrays["test_array"], dict):
213+
np.testing.assert_array_equal(loaded_arrays["test_array"]["test_array"], arrays["test_array"]["test_array"])
214+
else:
215+
np.testing.assert_array_equal(loaded_arrays["test_array"], arrays["test_array"])
205216

206217
# Edit with _edit_metadata
207218
def update_callback(file_path):
@@ -217,4 +228,31 @@ def update_callback(file_path):
217228
final_metadata, final_arrays = load_metadata(sample_checkpoint, supporting_arrays=True, name="metadata.json")
218229
assert final_metadata["edited"] is True
219230
assert final_metadata["test"] is True
220-
np.testing.assert_array_equal(final_arrays["test_array"], arrays["test_array"])
231+
if isinstance(loaded_arrays["test_array"], dict):
232+
np.testing.assert_array_equal(final_arrays["test_array"]["test_array"], arrays["test_array"]["test_array"])
233+
else:
234+
np.testing.assert_array_equal(final_arrays["test_array"], arrays["test_array"])
235+
236+
def test_metadata_no_arrays(self, sample_checkpoint):
237+
"""Test without supporting arrays."""
238+
# First remove existing metadata
239+
remove_metadata(sample_checkpoint, name="metadata.json")
240+
241+
# Add metadata with arrays
242+
metadata = {"version": "1.0", "test": True}
243+
244+
save_metadata(sample_checkpoint, metadata, supporting_arrays=None, name="metadata.json")
245+
246+
# Load and verify
247+
loaded_metadata = load_metadata(sample_checkpoint, supporting_arrays=None, name="metadata.json")
248+
assert loaded_metadata["test"] is True
249+
250+
# Edit with _edit_metadata
251+
def update_callback(file_path):
252+
with open(file_path, "r") as f:
253+
data = json.load(f)
254+
data["edited"] = True
255+
with open(file_path, "w") as f:
256+
json.dump(data, f)
257+
258+
_edit_metadata(sample_checkpoint, "metadata.json", update_callback)

0 commit comments

Comments (0)