Skip to content

Commit b5b1437

Browse files
authored
feat: Deprecate ai-models.json in favour of anemoi.json (#247)
## Description The default metadata name was `ai-models.json`; this PR deprecates it in favour of `anemoi.json`, with a warning logged to the console when the deprecated name is used. Only a warning is logged because this should be a seamless change for users: old checkpoints continue to work.
1 parent 25bb224 commit b5b1437

File tree

2 files changed

+81
-49
lines changed

2 files changed

+81
-49
lines changed

src/anemoi/utils/checkpoints.py

Lines changed: 71 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# (C) Copyright 2024 Anemoi contributors.
1+
# (C) Copyright 2024- Anemoi contributors.
22
#
33
# This software is licensed under the terms of the Apache Licence Version 2.0
44
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
@@ -19,15 +19,19 @@
1919
import zipfile
2020
from collections.abc import Callable
2121
from tempfile import TemporaryDirectory
22+
from typing import Literal
23+
from typing import overload
2224

2325
import numpy as np
2426
import tqdm
2527

2628
LOG = logging.getLogger(__name__)
2729

28-
DEFAULT_NAME = "ai-models.json"
30+
DEFAULT_NAME = "anemoi.json"
2931
DEFAULT_FOLDER = "anemoi-metadata"
3032

33+
DEPRECATED_NAME = "ai-models.json"
34+
3135

3236
def has_metadata(path: str, *, name: str = DEFAULT_NAME) -> bool:
3337
"""Check if a checkpoint file has a metadata file.
@@ -45,14 +49,11 @@ def has_metadata(path: str, *, name: str = DEFAULT_NAME) -> bool:
4549
True if the metadata file is found
4650
"""
4751
with zipfile.ZipFile(path, "r") as f:
48-
for b in f.namelist():
49-
if os.path.basename(b) == name:
50-
return True
51-
return False
52+
return any(os.path.basename(b) == name for b in f.namelist())
5253

5354

54-
def metadata_root(path: str, *, name: str = DEFAULT_NAME) -> str:
55-
"""Get the root directory of the metadata file.
55+
def get_metadata_path(path: str, *, name: str = DEFAULT_NAME) -> str:
56+
"""Get the full path of the metadata file in the checkpoint.
5657
5758
Parameters
5859
----------
@@ -64,21 +65,50 @@ def metadata_root(path: str, *, name: str = DEFAULT_NAME) -> str:
6465
Returns
6566
-------
6667
str
67-
The root directory of the metadata file
68+
The full path of the metadata file in the zip archive
6869
6970
Raises
7071
------
71-
ValueError
72+
FileNotFoundError
7273
If the metadata file is not found
74+
ValueError
75+
If multiple metadata files are found
7376
"""
7477
with zipfile.ZipFile(path, "r") as f:
75-
for b in f.namelist():
76-
if os.path.basename(b) == name:
77-
return os.path.dirname(b)
78-
raise ValueError(f"Could not find '{name}' in {path}.")
78+
metadata_file = list(filter(lambda b: os.path.basename(b) == name, f.namelist()))
79+
if len(metadata_file) == 0:
80+
raise FileNotFoundError(f"Could not find '{name}' in {path}.")
81+
if len(metadata_file) > 1:
82+
raise ValueError(f"Found two or more '{name}' in {path}.")
83+
return metadata_file[0]
84+
85+
86+
def _support_metadata_name_deprecation(path: str, name: str) -> str:
87+
"""Support deprecated metadata name, automatically switching if needed and logging a warning."""
88+
if name == DEFAULT_NAME and not has_metadata(path, name=DEFAULT_NAME):
89+
if has_metadata(path, name=DEPRECATED_NAME):
90+
LOG.warning(
91+
"The metadata file '%s' is deprecated. New versions of checkpoints will write to '%s' instead.",
92+
DEPRECATED_NAME,
93+
DEFAULT_NAME,
94+
)
95+
name = DEPRECATED_NAME
96+
return name
97+
98+
99+
# TODO: Refactor this function to reduce complexity
100+
@overload
101+
def load_metadata(path: str, *, supporting_arrays: Literal[False] = False, name: str = DEFAULT_NAME) -> dict: # type: ignore[reportOverlappingOverload]
102+
...
103+
104+
105+
@overload
106+
def load_metadata(
107+
path: str, *, supporting_arrays: Literal[True] = True, name: str = DEFAULT_NAME
108+
) -> tuple[dict, dict]: ...
79109

80110

81-
def load_metadata(path: str, *, supporting_arrays: bool = False, name: str = DEFAULT_NAME) -> dict:
111+
def load_metadata(path: str, *, supporting_arrays: bool = False, name: str = DEFAULT_NAME) -> dict | tuple[dict, dict]:
82112
"""Load metadata from a checkpoint file.
83113
84114
Parameters
@@ -102,24 +132,15 @@ def load_metadata(path: str, *, supporting_arrays: bool = False, name: str = DEF
102132
ValueError
103133
If the metadata file is not found
104134
"""
105-
with zipfile.ZipFile(path, "r") as f:
106-
metadata = None
107-
for b in f.namelist():
108-
if os.path.basename(b) == name:
109-
if metadata is not None:
110-
raise ValueError(f"Found two or more '{name}' in {path}.")
111-
metadata = b
112-
113-
if metadata is not None:
114-
with zipfile.ZipFile(path, "r") as f:
115-
metadata = json.load(f.open(metadata, "r"))
116-
if supporting_arrays:
117-
arrays = load_supporting_arrays(f, metadata.get("supporting_arrays_paths", {}))
118-
return metadata, arrays
135+
name = _support_metadata_name_deprecation(path, name)
136+
metadata = get_metadata_path(path, name=name)
119137

120-
return metadata
121-
else:
122-
raise ValueError(f"Could not find '{name}' in {path}.")
138+
with zipfile.ZipFile(path, "r") as f:
139+
metadata = json.load(f.open(metadata, "r"))
140+
if supporting_arrays:
141+
arrays = load_supporting_arrays(f, metadata.get("supporting_arrays_paths", {}))
142+
return metadata, arrays
143+
return metadata
123144

124145

125146
def load_supporting_arrays(zipf: zipfile.ZipFile, entries: dict) -> dict:
@@ -190,7 +211,12 @@ def _write_array_to_bytes(array: dict | np.ndarray, name: str, entry: dict, zipf
190211

191212

192213
def save_metadata(
193-
path: str, metadata: dict, *, supporting_arrays: dict = None, name: str = DEFAULT_NAME, folder: str = DEFAULT_FOLDER
214+
path: str,
215+
metadata: dict,
216+
*,
217+
supporting_arrays: dict | None = None,
218+
name: str = DEFAULT_NAME,
219+
folder: str = DEFAULT_FOLDER,
194220
) -> None:
195221
"""Save metadata to a checkpoint file.
196222
@@ -200,7 +226,7 @@ def save_metadata(
200226
The path to the checkpoint file
201227
metadata : dict
202228
A JSON serializable object
203-
supporting_arrays : dict, optional
229+
supporting_arrays : dict | None, optional
204230
A dictionary of supporting NumPy arrays
205231
name : str, optional
206232
The name of the metadata file in the zip archive
@@ -257,20 +283,14 @@ def _edit_metadata(path: str, name: str, callback: Callable, supporting_arrays:
257283
"""
258284
new_path = f"{path}.anemoi-edit-{time.time()}-{os.getpid()}.tmp"
259285

260-
with zipfile.ZipFile(path, "r") as source_zip:
261-
file_list = source_zip.namelist()
286+
target_file = get_metadata_path(path, name=name)
287+
if target_file is None:
288+
raise FileNotFoundError(f"Could not find '{name}' in {path}")
262289

263-
# Find the target file and its directory
264-
target_file = None
265-
directory = None
266-
for file_path in file_list:
267-
if os.path.basename(file_path) == name:
268-
target_file = file_path
269-
directory = os.path.dirname(file_path)
270-
break
290+
directory = os.path.dirname(target_file)
271291

272-
if target_file is None:
273-
raise ValueError(f"Could not find '{name}' in {path}")
292+
with zipfile.ZipFile(path, "r") as source_zip:
293+
file_list = source_zip.namelist()
274294

275295
# Calculate total files for progress bar
276296
total_files = len(file_list)
@@ -313,7 +333,9 @@ def _edit_metadata(path: str, name: str, callback: Callable, supporting_arrays:
313333
LOG.info("Updated metadata in %s", path)
314334

315335

316-
def replace_metadata(path: str, metadata: dict, supporting_arrays: dict = None, *, name: str = DEFAULT_NAME) -> None:
336+
def replace_metadata(
337+
path: str, metadata: dict, supporting_arrays: dict | None = None, *, name: str = DEFAULT_NAME
338+
) -> None:
317339
"""Replace metadata in a checkpoint file.
318340
319341
Parameters
@@ -337,6 +359,7 @@ def callback(full):
337359
with open(full, "w") as f:
338360
json.dump(metadata, f)
339361

362+
name = _support_metadata_name_deprecation(path, name)
340363
return _edit_metadata(path, name, callback, supporting_arrays)
341364

342365

@@ -350,6 +373,7 @@ def remove_metadata(path: str, *, name: str = DEFAULT_NAME) -> None:
350373
name : str, optional
351374
The name of the metadata file in the zip archive
352375
"""
376+
name = _support_metadata_name_deprecation(path, name)
353377

354378
def callback(full):
355379
os.remove(full)

tests/test_checkpoints.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
import numpy as np
77
import pytest
88

9+
from anemoi.utils.checkpoints import DEFAULT_NAME
10+
from anemoi.utils.checkpoints import DEPRECATED_NAME
911
from anemoi.utils.checkpoints import _edit_metadata
1012
from anemoi.utils.checkpoints import has_metadata
1113
from anemoi.utils.checkpoints import load_metadata
@@ -149,7 +151,7 @@ def test_edit_metadata_file_not_found(self, sample_checkpoint):
149151
def dummy_callback(file_path):
150152
pass
151153

152-
with pytest.raises(ValueError, match="Could not find 'nonexistent.json'"):
154+
with pytest.raises(FileNotFoundError, match="Could not find 'nonexistent.json'"):
153155
_edit_metadata(sample_checkpoint, "nonexistent.json", dummy_callback)
154156

155157
def test_edit_metadata_callback_exception_handling(self, sample_checkpoint):
@@ -244,7 +246,7 @@ def test_metadata_no_arrays(self, sample_checkpoint):
244246
save_metadata(sample_checkpoint, metadata, supporting_arrays=None, name="metadata.json")
245247

246248
# Load and verify
247-
loaded_metadata = load_metadata(sample_checkpoint, supporting_arrays=None, name="metadata.json")
249+
loaded_metadata = load_metadata(sample_checkpoint, supporting_arrays=False, name="metadata.json")
248250
assert loaded_metadata["test"] is True
249251

250252
# Edit with _edit_metadata
@@ -256,3 +258,9 @@ def update_callback(file_path):
256258
json.dump(data, f)
257259

258260
_edit_metadata(sample_checkpoint, "metadata.json", update_callback)
261+
262+
def test_automatic_deprecation_handling(self, sample_checkpoint: str):
263+
"""Test that deprecated metadata name is handled automatically."""
264+
save_metadata(sample_checkpoint, {"version": "1.0", "test": "Deprecation"}, name=DEPRECATED_NAME)
265+
metadata = load_metadata(sample_checkpoint, name=DEFAULT_NAME) # Should auto-switch and log warning
266+
assert metadata["test"] == "Deprecation"

0 commit comments

Comments
 (0)