1- # (C) Copyright 2024 Anemoi contributors.
1+ # (C) Copyright 2024- Anemoi contributors.
22#
33# This software is licensed under the terms of the Apache Licence Version 2.0
44# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
1919import zipfile
2020from collections .abc import Callable
2121from tempfile import TemporaryDirectory
22+ from typing import Literal
23+ from typing import overload
2224
2325import numpy as np
2426import tqdm
2527
2628LOG = logging .getLogger (__name__ )
2729
28- DEFAULT_NAME = "ai-models .json"
30+ DEFAULT_NAME = "anemoi .json"
2931DEFAULT_FOLDER = "anemoi-metadata"
3032
33+ DEPRECATED_NAME = "ai-models.json"
34+
3135
3236def has_metadata (path : str , * , name : str = DEFAULT_NAME ) -> bool :
3337 """Check if a checkpoint file has a metadata file.
@@ -45,14 +49,11 @@ def has_metadata(path: str, *, name: str = DEFAULT_NAME) -> bool:
4549 True if the metadata file is found
4650 """
4751 with zipfile .ZipFile (path , "r" ) as f :
48- for b in f .namelist ():
49- if os .path .basename (b ) == name :
50- return True
51- return False
52+ return any (os .path .basename (b ) == name for b in f .namelist ())
5253
5354
54- def metadata_root (path : str , * , name : str = DEFAULT_NAME ) -> str :
55- """Get the root directory of the metadata file.
55+ def get_metadata_path (path : str , * , name : str = DEFAULT_NAME ) -> str :
56+ """Get the full path of the metadata file in the checkpoint .
5657
5758 Parameters
5859 ----------
@@ -64,21 +65,50 @@ def metadata_root(path: str, *, name: str = DEFAULT_NAME) -> str:
6465 Returns
6566 -------
6667 str
67- The root directory of the metadata file
68+ The full path of the metadata file in the zip archive
6869
6970 Raises
7071 ------
71- ValueError
72+ FileNotFoundError
7273 If the metadata file is not found
74+ ValueError
75+ If multiple metadata files are found
7376 """
7477 with zipfile .ZipFile (path , "r" ) as f :
75- for b in f .namelist ():
76- if os .path .basename (b ) == name :
77- return os .path .dirname (b )
78- raise ValueError (f"Could not find '{ name } ' in { path } ." )
78+ metadata_file = list (filter (lambda b : os .path .basename (b ) == name , f .namelist ()))
79+ if len (metadata_file ) == 0 :
80+ raise FileNotFoundError (f"Could not find '{ name } ' in { path } ." )
81+ if len (metadata_file ) > 1 :
82+ raise ValueError (f"Found two or more '{ name } ' in { path } ." )
83+ return metadata_file [0 ]
84+
85+
def _support_metadata_name_deprecation(path: str, name: str) -> str:
    """Resolve the metadata file name, falling back to the deprecated one.

    Parameters
    ----------
    path : str
        The path to the checkpoint file
    name : str
        The requested name of the metadata file in the zip archive

    Returns
    -------
    str
        ``DEPRECATED_NAME`` when ``name`` is ``DEFAULT_NAME``, the checkpoint
        has no ``DEFAULT_NAME`` entry and it does have a ``DEPRECATED_NAME``
        entry (a warning is logged in that case); otherwise ``name`` unchanged
    """
    # A caller-chosen, non-default name is never rewritten.
    if name != DEFAULT_NAME:
        return name

    # The current default name is present in the archive: keep it.
    if has_metadata(path, name=DEFAULT_NAME):
        return name

    # No deprecated file either: hand the requested name back untouched.
    if not has_metadata(path, name=DEPRECATED_NAME):
        return name

    LOG.warning(
        "The metadata file '%s' is deprecated. New versions of checkpoints will write to '%s' instead.",
        DEPRECATED_NAME,
        DEFAULT_NAME,
    )
    return DEPRECATED_NAME
97+
98+
99+ # TODO: Refactor this function to reduce complexity
# Typing-only stubs: they tell type checkers how the return type of
# ``load_metadata`` depends on ``supporting_arrays``.
@overload
def load_metadata(path: str, *, supporting_arrays: Literal[False] = False, name: str = DEFAULT_NAME) -> dict: ...


# No default here: only the ``Literal[False]`` overload may match a call that
# omits ``supporting_arrays``, so the two stubs no longer overlap and no
# ``type: ignore`` is needed.
@overload
def load_metadata(path: str, *, supporting_arrays: Literal[True], name: str = DEFAULT_NAME) -> tuple[dict, dict]: ...


# Fallback for callers passing a ``bool`` whose value is not known statically.
@overload
def load_metadata(path: str, *, supporting_arrays: bool, name: str = DEFAULT_NAME) -> dict | tuple[dict, dict]: ...
79109
80110
81- def load_metadata (path : str , * , supporting_arrays : bool = False , name : str = DEFAULT_NAME ) -> dict :
111+ def load_metadata (path : str , * , supporting_arrays : bool = False , name : str = DEFAULT_NAME ) -> dict | tuple [ dict , dict ] :
82112 """Load metadata from a checkpoint file.
83113
84114 Parameters
@@ -102,24 +132,15 @@ def load_metadata(path: str, *, supporting_arrays: bool = False, name: str = DEF
102132 ValueError
103133 If the metadata file is not found
104134 """
105- with zipfile .ZipFile (path , "r" ) as f :
106- metadata = None
107- for b in f .namelist ():
108- if os .path .basename (b ) == name :
109- if metadata is not None :
110- raise ValueError (f"Found two or more '{ name } ' in { path } ." )
111- metadata = b
112-
113- if metadata is not None :
114- with zipfile .ZipFile (path , "r" ) as f :
115- metadata = json .load (f .open (metadata , "r" ))
116- if supporting_arrays :
117- arrays = load_supporting_arrays (f , metadata .get ("supporting_arrays_paths" , {}))
118- return metadata , arrays
135+ name = _support_metadata_name_deprecation (path , name )
136+ metadata = get_metadata_path (path , name = name )
119137
120- return metadata
121- else :
122- raise ValueError (f"Could not find '{ name } ' in { path } ." )
138+ with zipfile .ZipFile (path , "r" ) as f :
139+ metadata = json .load (f .open (metadata , "r" ))
140+ if supporting_arrays :
141+ arrays = load_supporting_arrays (f , metadata .get ("supporting_arrays_paths" , {}))
142+ return metadata , arrays
143+ return metadata
123144
124145
125146def load_supporting_arrays (zipf : zipfile .ZipFile , entries : dict ) -> dict :
@@ -190,7 +211,12 @@ def _write_array_to_bytes(array: dict | np.ndarray, name: str, entry: dict, zipf
190211
191212
192213def save_metadata (
193- path : str , metadata : dict , * , supporting_arrays : dict = None , name : str = DEFAULT_NAME , folder : str = DEFAULT_FOLDER
214+ path : str ,
215+ metadata : dict ,
216+ * ,
217+ supporting_arrays : dict | None = None ,
218+ name : str = DEFAULT_NAME ,
219+ folder : str = DEFAULT_FOLDER ,
194220) -> None :
195221 """Save metadata to a checkpoint file.
196222
@@ -200,7 +226,7 @@ def save_metadata(
200226 The path to the checkpoint file
201227 metadata : dict
202228 A JSON serializable object
203- supporting_arrays : dict, optional
229+ supporting_arrays : dict | None , optional
204230 A dictionary of supporting NumPy arrays
205231 name : str, optional
206232 The name of the metadata file in the zip archive
@@ -257,20 +283,14 @@ def _edit_metadata(path: str, name: str, callback: Callable, supporting_arrays:
257283 """
258284 new_path = f"{ path } .anemoi-edit-{ time .time ()} -{ os .getpid ()} .tmp"
259285
260- with zipfile .ZipFile (path , "r" ) as source_zip :
261- file_list = source_zip .namelist ()
286+ target_file = get_metadata_path (path , name = name )
287+ if target_file is None :
288+ raise FileNotFoundError (f"Could not find '{ name } ' in { path } " )
262289
263- # Find the target file and its directory
264- target_file = None
265- directory = None
266- for file_path in file_list :
267- if os .path .basename (file_path ) == name :
268- target_file = file_path
269- directory = os .path .dirname (file_path )
270- break
290+ directory = os .path .dirname (target_file )
271291
272- if target_file is None :
273- raise ValueError ( f"Could not find ' { name } ' in { path } " )
292+ with zipfile . ZipFile ( path , "r" ) as source_zip :
293+ file_list = source_zip . namelist ( )
274294
275295 # Calculate total files for progress bar
276296 total_files = len (file_list )
@@ -313,7 +333,9 @@ def _edit_metadata(path: str, name: str, callback: Callable, supporting_arrays:
313333 LOG .info ("Updated metadata in %s" , path )
314334
315335
316- def replace_metadata (path : str , metadata : dict , supporting_arrays : dict = None , * , name : str = DEFAULT_NAME ) -> None :
336+ def replace_metadata (
337+ path : str , metadata : dict , supporting_arrays : dict | None = None , * , name : str = DEFAULT_NAME
338+ ) -> None :
317339 """Replace metadata in a checkpoint file.
318340
319341 Parameters
@@ -337,6 +359,7 @@ def callback(full):
337359 with open (full , "w" ) as f :
338360 json .dump (metadata , f )
339361
362+ name = _support_metadata_name_deprecation (path , name )
340363 return _edit_metadata (path , name , callback , supporting_arrays )
341364
342365
@@ -350,6 +373,7 @@ def remove_metadata(path: str, *, name: str = DEFAULT_NAME) -> None:
350373 name : str, optional
351374 The name of the metadata file in the zip archive
352375 """
376+ name = _support_metadata_name_deprecation (path , name )
353377
354378 def callback (full ):
355379 os .remove (full )
0 commit comments