1414"""TypeHandlers supporting Pathways backend."""
1515
1616import collections
17- from collections .abc import Sequence
17+ from collections .abc import Coroutine , Sequence
1818import concurrent .futures
1919import datetime
2020import functools
2121import logging
22- import typing
22+ from typing import Any , cast
2323
2424import jax
2525from orbax .checkpoint import future
2626from orbax .checkpoint import type_handlers
27+ from orbax .checkpoint ._src .metadata import array_metadata as array_metadata_lib
28+ from orbax .checkpoint ._src .metadata import array_metadata_store as array_metadata_store_lib
2729from pathwaysutils .persistence import helper
2830
2931
3335SaveArgs = type_handlers .SaveArgs
3436RestoreArgs = type_handlers .RestoreArgs
3537ArrayRestoreArgs = type_handlers .ArrayRestoreArgs
38+ ArrayMetadata = array_metadata_lib .ArrayMetadata
3639
3740
3841def extract_parent_dir_and_name (
@@ -49,25 +52,34 @@ class CloudPathwaysArrayHandler(type_handlers.ArrayHandler):
4952
def __init__(
    self,
    timeout: datetime.timedelta | None = None,
    use_ocdbt: bool = False,
    array_metadata_store: array_metadata_store_lib.Store | None = None,
):
  """Orbax array handler for Pathways on Cloud with Persistence API.

  Args:
    timeout: Duration indicating the timeout for reading and writing arrays.
      Defaults to one hour when None.
    use_ocdbt: allows using Tensorstore OCDBT driver. Must be False; OCDBT
      is not supported by this handler.
    array_metadata_store: An optional store for writing and reading array
      metadata. Only required for saving new-style jax random keys.

  Raises:
    ValueError: If `use_ocdbt` is True.
  """
  # Reject the unsupported configuration before touching any instance state.
  if use_ocdbt:
    raise ValueError("OCDBT not supported for Pathways.")

  if timeout is None:
    timeout = datetime.timedelta(hours=1)
  # A single timeout is shared by the read and the write paths.
  self.timeout = timeout

  super().__init__(array_metadata_store=array_metadata_store)
6674
6775 async def _background_serialize (
6876 self ,
6977 futures_results : Sequence [concurrent .futures .Future [None ]],
78+ metadata_coroutine : Coroutine [Any , Any , None ] | None = None ,
7079 ) -> None :
80+ if metadata_coroutine :
81+ await metadata_coroutine
82+
7183 for future_result in futures_results :
7284 future_result .result ()
7385
@@ -90,14 +102,53 @@ async def serialize(
90102 if any ([arg .dtype is not None for arg in args ]):
91103 raise ValueError ("Casting during save not supported for Pathways." )
92104
105+ array_metadatas = []
106+ arrays = []
107+ for v , info , arg in zip (values , infos , args ):
108+ if jax .dtypes .issubdtype (v .dtype , jax .dtypes .prng_key ):
109+ # a JAX random key
110+ arrays .append (jax .random .key_data (v ))
111+ array_metadatas .append (
112+ ArrayMetadata (
113+ param_name = info .name ,
114+ shape = v .shape ,
115+ dtype = (arg .dtype if arg is not None else v .dtype ),
116+ write_shape = getattr (v , "local_shape" , v .shape ),
117+ chunk_shape = getattr (v , "local_shape" , v .shape ),
118+ use_ocdbt = False ,
119+ use_zarr3 = False ,
120+ ext_metadata = {
121+ array_metadata_lib .RANDOM_KEY_IMPL : str (
122+ jax .random .key_impl (v )
123+ )
124+ },
125+ )
126+ )
127+ else :
128+ arrays .append (v )
129+
130+ metadata_coroutine = None
131+ if array_metadatas :
132+ if self ._array_metadata_store is None :
133+ raise ValueError (
134+ "Array metadata store is not set with a checkpoint that requires"
135+ f" it. Array metadata: { array_metadatas } "
136+ )
137+
138+ metadata_coroutine = self ._array_metadata_store .write (
139+ checkpoint_dir = infos [0 ].parent_dir ,
140+ array_metadatas = array_metadatas ,
141+ process_index = 0 ,
142+ )
143+
93144 self ._wait_for_directory_creation_signals ()
94145 locations , names = extract_parent_dir_and_name (infos )
95- f = functools .partial (helper .write_one_array , timeout = self ._read_timeout )
96- futures_results = list (map (f , locations , names , values ))
146+ f = functools .partial (helper .write_one_array , timeout = self .timeout )
147+ futures_results = list (map (f , locations , names , arrays ))
97148
98149 return [
99150 future .CommitFutureAwaitingContractedSignals (
100- self ._background_serialize (futures_results ),
151+ self ._background_serialize (futures_results , metadata_coroutine ),
101152 name = "cloud_pathways_array_handler" ,
102153 )
103154 ]
@@ -106,7 +157,7 @@ async def deserialize(
106157 self ,
107158 infos : Sequence [ParamInfo ],
108159 args : Sequence [RestoreArgs ] | None = None ,
109- ) -> Sequence [jax .Array ]:
160+ ) -> list [jax .Array ]:
110161 """Uses Pathways Persistence API to deserialize a jax array."""
111162 if args is None :
112163 raise ValueError ("Must provide ArrayRestoreArgs to restore as jax.Array." )
@@ -125,7 +176,7 @@ async def deserialize(
125176 "To restore jax.Array, provide ArrayRestoreArgs; found"
126177 f" { type (arg ).__name__ } "
127178 )
128- arg = typing . cast (ArrayRestoreArgs , arg )
179+ arg = cast (ArrayRestoreArgs , arg )
129180 if arg .sharding is None and (arg .mesh is None or arg .mesh_axes is None ):
130181 raise ValueError (
131182 "Sharding of jax.Array cannot be None. Provide `mesh`"
@@ -140,7 +191,7 @@ async def deserialize(
140191 else :
141192 if not isinstance (arg .sharding , jax .sharding .NamedSharding ):
142193 raise ValueError ("Pathways only supports jax.sharding.NamedSharding." )
143- sharding = typing . cast (jax .sharding .NamedSharding , arg .sharding )
194+ sharding = cast (jax .sharding .NamedSharding , arg .sharding )
144195 global_meshes .append (sharding .mesh )
145196 mesh_axes .append (sharding .spec )
146197 shardings .append (sharding )
@@ -160,13 +211,30 @@ async def deserialize(
160211 ]
161212 dtypes = [m .dtype if d is None else d for m , d in zip (metadatas , dtypes )]
162213
214+ array_metadatas_cache = {}
215+ if self ._array_metadata_store is not None :
216+ array_metadatas = await self ._array_metadata_store .read (
217+ checkpoint_dir = infos [0 ].parent_dir ,
218+ process_index = 0 ,
219+ )
220+ if not isinstance (array_metadatas , list ):
221+ raise ValueError (
222+ "Array metadata store returned unexpected result:"
223+ f" { array_metadatas } "
224+ )
225+
226+ array_metadatas_cache = {
227+ array_metadata .param_name : array_metadata
228+ for array_metadata in array_metadatas
229+ }
230+
163231 # Group inputs by global_mesh so that we can perform batched Array
164232 # construction for each global_mesh.
165233 inputs_by_global_mesh = collections .defaultdict (list )
166234 for i , global_mesh in enumerate (global_meshes ):
167235 inputs_by_global_mesh [global_mesh ].append (i )
168236
169- results = [ None ] * len (infos )
237+ results = cast ( list [ jax . Array ], [ None ] * len (infos ) )
170238
171239 for global_mesh , idxs in inputs_by_global_mesh .items ():
172240 grouped_infos = [infos [idx ] for idx in idxs ]
@@ -181,17 +249,30 @@ async def deserialize(
181249 grouped_global_shapes ,
182250 grouped_shardings ,
183251 global_mesh .devices ,
184- timeout = self ._read_timeout ,
252+ timeout = self .timeout ,
185253 )
186254 # each persistence call is awaited serially.
187255 read_future .result ()
188- for idx , arr in zip (idxs , grouped_arrays ):
256+ for idx , info , arr in zip (idxs , grouped_infos , grouped_arrays ):
257+ if meta := array_metadatas_cache .get (info .name ):
258+ assert isinstance (
259+ meta , array_metadata_lib .SerializedArrayMetadata
260+ ), f"Expecting SerializedArrayMetadata but got { type (meta )} ."
261+ assert isinstance (meta .ext_metadata , dict ), (
262+ "Expecting ext_metadata to be a dict but got"
263+ f" { type (meta .ext_metadata )} ."
264+ )
265+
266+ if impl := meta .ext_metadata .get (array_metadata_lib .RANDOM_KEY_IMPL ):
267+ arr = jax .random .wrap_key_data (arr , impl = impl )
189268 results [idx ] = arr
190- return results # pytype: disable=bad-return-type
269+
270+ return results
191271
192272
193273def register_pathways_handlers (
194- read_timeout : datetime .timedelta | None = None ,
274+ timeout : datetime .timedelta | None = None ,
275+ array_metadata_store : array_metadata_store_lib .Store | None = None ,
195276):
196277 """Function that must be called before saving or restoring with Pathways."""
197278 logger .debug (
@@ -200,7 +281,8 @@ def register_pathways_handlers(
200281 type_handlers .register_type_handler (
201282 jax .Array ,
202283 CloudPathwaysArrayHandler (
203- read_timeout = read_timeout ,
284+ timeout = timeout ,
285+ array_metadata_store = array_metadata_store ,
204286 ),
205287 override = True ,
206288 )
0 commit comments