From 2c6e3a4aab605411e27eb881e3bc1dde41a7ff92 Mon Sep 17 00:00:00 2001 From: Luke Baumann Date: Fri, 3 Oct 2025 14:21:10 -0700 Subject: [PATCH] Testing if every array needs metadata in the array metadata store PiperOrigin-RevId: 814829164 --- pathwaysutils/_initialize.py | 2 + pathwaysutils/persistence/orbax_handler.py | 104 ++++++++++++++++++--- 2 files changed, 94 insertions(+), 12 deletions(-) diff --git a/pathwaysutils/_initialize.py b/pathwaysutils/_initialize.py index a8df2a0..27476c9 100644 --- a/pathwaysutils/_initialize.py +++ b/pathwaysutils/_initialize.py @@ -18,6 +18,7 @@ import os import jax +from orbax.checkpoint._src.metadata import array_metadata_store as array_metadata_store_lib from pathwaysutils import profiling from pathwaysutils import proxy_backend from pathwaysutils.persistence import orbax_handler @@ -94,6 +95,7 @@ def initialize() -> None: if _is_persistence_enabled(): orbax_handler.register_pathways_handlers( timeout=datetime.timedelta(hours=1), + array_metadata_store=array_metadata_store_lib.Store(), ) # Turn off JAX compilation cache because Pathways handles its own diff --git a/pathwaysutils/persistence/orbax_handler.py b/pathwaysutils/persistence/orbax_handler.py index 9a35ca8..5fe9698 100644 --- a/pathwaysutils/persistence/orbax_handler.py +++ b/pathwaysutils/persistence/orbax_handler.py @@ -14,16 +14,18 @@ """TypeHandlers supporting Pathways backend.""" import collections -from collections.abc import Sequence +from collections.abc import Coroutine, Sequence import concurrent.futures import datetime import functools import logging -import typing +from typing import Any, cast import jax from orbax.checkpoint import future from orbax.checkpoint import type_handlers +from orbax.checkpoint._src.metadata import array_metadata as array_metadata_lib +from orbax.checkpoint._src.metadata import array_metadata_store as array_metadata_store_lib from pathwaysutils.persistence import helper @@ -33,6 +35,7 @@ SaveArgs = type_handlers.SaveArgs 
RestoreArgs = type_handlers.RestoreArgs ArrayRestoreArgs = type_handlers.ArrayRestoreArgs +ArrayMetadata = array_metadata_lib.ArrayMetadata def extract_parent_dir_and_name( @@ -51,6 +54,7 @@ def __init__( self, timeout: datetime.timedelta | None = None, use_ocdbt: bool = False, + array_metadata_store: array_metadata_store_lib.Store | None = None, ): """Orbax array handler for Pathways on Cloud with Persistence API. @@ -58,6 +62,8 @@ def __init__( timeout: Duration indicating the timeout for reading and writing arrays. Default is 1 hour. use_ocdbt: allows using Tensorstore OCDBT driver. + array_metadata_store: Store for writing and reading array metadata. + serialize() raises ValueError if arrays are saved without one. """ if timeout is None: timeout = datetime.timedelta(hours=1) @@ -65,12 +71,16 @@ def __init__( if use_ocdbt: raise ValueError("OCDBT not supported for Pathways.") - super().__init__() + super().__init__(array_metadata_store=array_metadata_store) async def _background_serialize( self, futures_results: Sequence[concurrent.futures.Future[None]], + metadata_coroutine: Coroutine[Any, Any, None] | None = None, ) -> None: + if metadata_coroutine is not None: + await metadata_coroutine + for future_result in futures_results: future_result.result() @@ -86,21 +96,60 @@ async def serialize( values: Sequence[jax.Array], infos: Sequence[ParamInfo], args: Sequence[SaveArgs] | None = None, - ) -> Sequence[future.Future]: + ) -> list[future.Future]: """Uses Pathways Persistence API to serialize a jax array.""" type_handlers.check_input_arguments(values, infos, args) if any([arg.dtype is not None for arg in args]): raise ValueError("Casting during save not supported for Pathways.") + array_metadatas = [] + arrays = [] + for v, info, arg in zip(values, infos, args): + ext_metadata = None + if jax.dtypes.issubdtype(v.dtype, jax.dtypes.prng_key): + # a JAX random key: record its impl, then unwrap to raw key data + key_impl, v = str(jax.random.key_impl(v)), jax.random.key_data(v) + ext_metadata = { + array_metadata_lib.RANDOM_KEY_IMPL: key_impl + 
} + + array_metadatas.append( + ArrayMetadata( + param_name=info.name, + shape=v.shape, + dtype=(v.dtype if arg is None or arg.dtype is None else arg.dtype), + write_shape=getattr(v, "local_shape", v.shape), + chunk_shape=getattr(v, "local_shape", v.shape), + use_ocdbt=False, + use_zarr3=False, + ext_metadata=ext_metadata, + ) + ) + arrays.append(v) + + metadata_coroutine = None + if array_metadatas: + if self._array_metadata_store is None: + raise ValueError( + "Array metadata store is not set with a checkpoint that requires" + f" it. Array metadata: {array_metadatas}" + ) + + metadata_coroutine = self._array_metadata_store.write( + checkpoint_dir=infos[0].parent_dir, + array_metadatas=array_metadatas, + process_index=0, + ) + self._wait_for_directory_creation_signals() locations, names = extract_parent_dir_and_name(infos) f = functools.partial(helper.write_one_array, timeout=self.timeout) - futures_results = list(map(f, locations, names, values)) + futures_results = list(map(f, locations, names, arrays)) return [ future.CommitFutureAwaitingContractedSignals( - self._background_serialize(futures_results), + self._background_serialize(futures_results, metadata_coroutine), name="cloud_pathways_array_handler", ) ] @@ -109,7 +158,7 @@ async def deserialize( self, infos: Sequence[ParamInfo], args: Sequence[RestoreArgs] | None = None, - ) -> Sequence[jax.Array]: + ) -> list[jax.Array]: """Uses Pathways Persistence API to deserialize a jax array.""" if args is None: raise ValueError("Must provide ArrayRestoreArgs to restore as jax.Array.") @@ -128,7 +177,7 @@ async def deserialize( "To restore jax.Array, provide ArrayRestoreArgs; found" f" {type(arg).__name__}" ) - arg = typing.cast(ArrayRestoreArgs, arg) + arg = cast(ArrayRestoreArgs, arg) if arg.sharding is None and (arg.mesh is None or arg.mesh_axes is None): raise ValueError( "Sharding of jax.Array cannot be None. 
Provide `mesh`" @@ -143,7 +192,7 @@ async def deserialize( else: if not isinstance(arg.sharding, jax.sharding.NamedSharding): raise ValueError("Pathways only supports jax.sharding.NamedSharding.") - sharding = typing.cast(jax.sharding.NamedSharding, arg.sharding) + sharding = cast(jax.sharding.NamedSharding, arg.sharding) global_meshes.append(sharding.mesh) mesh_axes.append(sharding.spec) shardings.append(sharding) @@ -163,13 +212,30 @@ async def deserialize( ] dtypes = [m.dtype if d is None else d for m, d in zip(metadatas, dtypes)] + array_metadatas_cache = {} + if self._array_metadata_store is not None: + if array_metadatas := await self._array_metadata_store.read( + checkpoint_dir=infos[0].parent_dir, + process_index=0, + ): + if not isinstance(array_metadatas, list): + raise ValueError( + "Array metadata store returned unexpected result:" + f" {array_metadatas}" + ) + + array_metadatas_cache = { + array_metadata.param_name: array_metadata + for array_metadata in array_metadatas + } + # Group inputs by global_mesh so that we can perform batched Array # construction for each global_mesh. inputs_by_global_mesh = collections.defaultdict(list) for i, global_mesh in enumerate(global_meshes): inputs_by_global_mesh[global_mesh].append(i) - results = [None] * len(infos) + results = cast(list[jax.Array], [None] * len(infos)) for global_mesh, idxs in inputs_by_global_mesh.items(): grouped_infos = [infos[idx] for idx in idxs] @@ -188,13 +254,26 @@ async def deserialize( ) # each persistence call is awaited serially. read_future.result() - for idx, arr in zip(idxs, grouped_arrays): + for idx, info, arr in zip(idxs, grouped_infos, grouped_arrays): + if meta := array_metadatas_cache.get(info.name): + assert isinstance( + meta, array_metadata_lib.SerializedArrayMetadata + ), f"Expecting SerializedArrayMetadata but got {type(meta)}." + assert isinstance(meta.ext_metadata, dict), ( + "Expecting ext_metadata to be a dict but got" + f" {type(meta.ext_metadata)}." 
+ ) + + if impl := meta.ext_metadata.get(array_metadata_lib.RANDOM_KEY_IMPL): + arr = jax.random.wrap_key_data(arr, impl=impl) results[idx] = arr - return results # pytype: disable=bad-return-type + + return results def register_pathways_handlers( timeout: datetime.timedelta | None = None, + array_metadata_store: array_metadata_store_lib.Store | None = None, ): """Function that must be called before saving or restoring with Pathways.""" logger.debug( @@ -204,6 +283,7 @@ def register_pathways_handlers( jax.Array, CloudPathwaysArrayHandler( timeout=timeout, + array_metadata_store=array_metadata_store, ), override=True, )