Skip to content

Commit 8ab6ebb

Browse files
lukebaumanncopybara-github
authored andcommitted
Support jax.random.PRNGKey serialization in Pathways Orbax handler.
This change allows `CloudPathwaysArrayHandler` to correctly save and restore `jax.random.PRNGKey` objects by extracting and wrapping the key data, and storing metadata about the key implementation using an `ArrayMetadataStore`. This change introduces a dependency on Orbax's internal API. PiperOrigin-RevId: 813440677
1 parent 2fa0623 commit 8ab6ebb

File tree

2 files changed

+106
-20
lines changed

2 files changed

+106
-20
lines changed

pathwaysutils/_initialize.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import os
1919

2020
import jax
21+
from orbax.checkpoint._src.metadata import array_metadata_store as array_metadata_store_lib
2122
from pathwaysutils import profiling
2223
from pathwaysutils import proxy_backend
2324
from pathwaysutils.persistence import orbax_handler
@@ -92,7 +93,10 @@ def initialize() -> None:
9293
profiling.monkey_patch_jax()
9394
# TODO: b/365549911 - Remove when OCDBT-compatible
9495
if _is_persistence_enabled():
95-
orbax_handler.register_pathways_handlers(datetime.timedelta(hours=1))
96+
orbax_handler.register_pathways_handlers(
97+
timeout=datetime.timedelta(hours=1),
98+
array_metadata_store=array_metadata_store_lib.Store(),
99+
)
96100

97101
# Turn off JAX compilation cache because Pathways handles its own
98102
# compilation cache.

pathwaysutils/persistence/orbax_handler.py

Lines changed: 101 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,18 @@
1414
"""TypeHandlers supporting Pathways backend."""
1515

1616
import collections
17-
from collections.abc import Sequence
17+
from collections.abc import Coroutine, Sequence
1818
import concurrent.futures
1919
import datetime
2020
import functools
2121
import logging
22-
import typing
22+
from typing import Any, cast
2323

2424
import jax
2525
from orbax.checkpoint import future
2626
from orbax.checkpoint import type_handlers
27+
from orbax.checkpoint._src.metadata import array_metadata as array_metadata_lib
28+
from orbax.checkpoint._src.metadata import array_metadata_store as array_metadata_store_lib
2729
from pathwaysutils.persistence import helper
2830

2931

@@ -33,6 +35,7 @@
3335
SaveArgs = type_handlers.SaveArgs
3436
RestoreArgs = type_handlers.RestoreArgs
3537
ArrayRestoreArgs = type_handlers.ArrayRestoreArgs
38+
ArrayMetadata = array_metadata_lib.ArrayMetadata
3639

3740

3841
def extract_parent_dir_and_name(
@@ -49,25 +52,34 @@ class CloudPathwaysArrayHandler(type_handlers.ArrayHandler):
4952

5053
def __init__(
5154
self,
52-
read_timeout: datetime.timedelta | None = None,
55+
timeout: datetime.timedelta | None = None,
5356
use_ocdbt: bool = False,
57+
array_metadata_store: array_metadata_store_lib.Store | None = None,
5458
):
55-
"""Constructor.
59+
"""Orbax array handler for Pathways on Cloud with Persistence API.
5660
5761
Args:
58-
read_timeout: Duration indicating the timeout for reading arrays
62+
timeout: Duration indicating the timeout for reading and writing arrays
5963
use_ocdbt: allows using Tensorstore OCDBT driver.
64+
array_metadata_store: An optional store for writing and reading array
65+
metadata. Only required for saving new-style jax random keys.
6066
"""
61-
self._read_timeout = read_timeout
67+
if timeout is None:
68+
timeout = datetime.timedelta(hours=1)
69+
self.timeout = timeout
6270

6371
if use_ocdbt:
6472
raise ValueError("OCDBT not supported for Pathways.")
65-
super().__init__()
73+
super().__init__(array_metadata_store=array_metadata_store)
6674

6775
async def _background_serialize(
6876
self,
6977
futures_results: Sequence[concurrent.futures.Future[None]],
78+
metadata_coroutine: Coroutine[Any, Any, None] | None = None,
7079
) -> None:
80+
if metadata_coroutine:
81+
await metadata_coroutine
82+
7183
for future_result in futures_results:
7284
future_result.result()
7385

@@ -90,14 +102,53 @@ async def serialize(
90102
if any([arg.dtype is not None for arg in args]):
91103
raise ValueError("Casting during save not supported for Pathways.")
92104

105+
array_metadatas = []
106+
arrays = []
107+
for v, info, arg in zip(values, infos, args):
108+
if jax.dtypes.issubdtype(v.dtype, jax.dtypes.prng_key):
109+
# a JAX random key
110+
arrays.append(jax.random.key_data(v))
111+
array_metadatas.append(
112+
ArrayMetadata(
113+
param_name=info.name,
114+
shape=v.shape,
115+
dtype=(arg.dtype if arg is not None else v.dtype),
116+
write_shape=getattr(v, "local_shape", v.shape),
117+
chunk_shape=getattr(v, "local_shape", v.shape),
118+
use_ocdbt=False,
119+
use_zarr3=False,
120+
ext_metadata={
121+
array_metadata_lib.RANDOM_KEY_IMPL: str(
122+
jax.random.key_impl(v)
123+
)
124+
},
125+
)
126+
)
127+
else:
128+
arrays.append(v)
129+
130+
metadata_coroutine = None
131+
if array_metadatas:
132+
if self._array_metadata_store is None:
133+
raise ValueError(
134+
"Array metadata store is not set with a checkpoint that requires"
135+
f" it. Array metadata: {array_metadatas}"
136+
)
137+
138+
metadata_coroutine = self._array_metadata_store.write(
139+
checkpoint_dir=infos[0].parent_dir,
140+
array_metadatas=array_metadatas,
141+
process_index=0,
142+
)
143+
93144
self._wait_for_directory_creation_signals()
94145
locations, names = extract_parent_dir_and_name(infos)
95-
f = functools.partial(helper.write_one_array, timeout=self._read_timeout)
96-
futures_results = list(map(f, locations, names, values))
146+
f = functools.partial(helper.write_one_array, timeout=self.timeout)
147+
futures_results = list(map(f, locations, names, arrays))
97148

98149
return [
99150
future.CommitFutureAwaitingContractedSignals(
100-
self._background_serialize(futures_results),
151+
self._background_serialize(futures_results, metadata_coroutine),
101152
name="cloud_pathways_array_handler",
102153
)
103154
]
@@ -106,7 +157,7 @@ async def deserialize(
106157
self,
107158
infos: Sequence[ParamInfo],
108159
args: Sequence[RestoreArgs] | None = None,
109-
) -> Sequence[jax.Array]:
160+
) -> list[jax.Array]:
110161
"""Uses Pathways Persistence API to deserialize a jax array."""
111162
if args is None:
112163
raise ValueError("Must provide ArrayRestoreArgs to restore as jax.Array.")
@@ -125,7 +176,7 @@ async def deserialize(
125176
"To restore jax.Array, provide ArrayRestoreArgs; found"
126177
f" {type(arg).__name__}"
127178
)
128-
arg = typing.cast(ArrayRestoreArgs, arg)
179+
arg = cast(ArrayRestoreArgs, arg)
129180
if arg.sharding is None and (arg.mesh is None or arg.mesh_axes is None):
130181
raise ValueError(
131182
"Sharding of jax.Array cannot be None. Provide `mesh`"
@@ -140,7 +191,7 @@ async def deserialize(
140191
else:
141192
if not isinstance(arg.sharding, jax.sharding.NamedSharding):
142193
raise ValueError("Pathways only supports jax.sharding.NamedSharding.")
143-
sharding = typing.cast(jax.sharding.NamedSharding, arg.sharding)
194+
sharding = cast(jax.sharding.NamedSharding, arg.sharding)
144195
global_meshes.append(sharding.mesh)
145196
mesh_axes.append(sharding.spec)
146197
shardings.append(sharding)
@@ -160,13 +211,30 @@ async def deserialize(
160211
]
161212
dtypes = [m.dtype if d is None else d for m, d in zip(metadatas, dtypes)]
162213

214+
array_metadatas_cache = {}
215+
if self._array_metadata_store is not None:
216+
array_metadatas = await self._array_metadata_store.read(
217+
checkpoint_dir=infos[0].parent_dir,
218+
process_index=0,
219+
)
220+
if not isinstance(array_metadatas, list):
221+
raise ValueError(
222+
"Array metadata store returned unexpected result:"
223+
f" {array_metadatas}"
224+
)
225+
226+
array_metadatas_cache = {
227+
array_metadata.param_name: array_metadata
228+
for array_metadata in array_metadatas
229+
}
230+
163231
# Group inputs by global_mesh so that we can perform batched Array
164232
# construction for each global_mesh.
165233
inputs_by_global_mesh = collections.defaultdict(list)
166234
for i, global_mesh in enumerate(global_meshes):
167235
inputs_by_global_mesh[global_mesh].append(i)
168236

169-
results = [None] * len(infos)
237+
results = cast(list[jax.Array], [None] * len(infos))
170238

171239
for global_mesh, idxs in inputs_by_global_mesh.items():
172240
grouped_infos = [infos[idx] for idx in idxs]
@@ -181,17 +249,30 @@ async def deserialize(
181249
grouped_global_shapes,
182250
grouped_shardings,
183251
global_mesh.devices,
184-
timeout=self._read_timeout,
252+
timeout=self.timeout,
185253
)
186254
# each persistence call is awaited serially.
187255
read_future.result()
188-
for idx, arr in zip(idxs, grouped_arrays):
256+
for idx, info, arr in zip(idxs, grouped_infos, grouped_arrays):
257+
if meta := array_metadatas_cache.get(info.name):
258+
assert isinstance(
259+
meta, array_metadata_lib.SerializedArrayMetadata
260+
), f"Expecting SerializedArrayMetadata but got {type(meta)}."
261+
assert isinstance(meta.ext_metadata, dict), (
262+
"Expecting ext_metadata to be a dict but got"
263+
f" {type(meta.ext_metadata)}."
264+
)
265+
266+
if impl := meta.ext_metadata.get(array_metadata_lib.RANDOM_KEY_IMPL):
267+
arr = jax.random.wrap_key_data(arr, impl=impl)
189268
results[idx] = arr
190-
return results # pytype: disable=bad-return-type
269+
270+
return results
191271

192272

193273
def register_pathways_handlers(
194-
read_timeout: datetime.timedelta | None = None,
274+
timeout: datetime.timedelta | None = None,
275+
array_metadata_store: array_metadata_store_lib.Store | None = None,
195276
):
196277
"""Function that must be called before saving or restoring with Pathways."""
197278
logger.debug(
@@ -200,7 +281,8 @@ def register_pathways_handlers(
200281
type_handlers.register_type_handler(
201282
jax.Array,
202283
CloudPathwaysArrayHandler(
203-
read_timeout=read_timeout,
284+
timeout=timeout,
285+
array_metadata_store=array_metadata_store,
204286
),
205287
override=True,
206288
)

0 commit comments

Comments
 (0)