Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pathwaysutils/_initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import os

import jax
from orbax.checkpoint._src.metadata import array_metadata_store as array_metadata_store_lib
from pathwaysutils import profiling
from pathwaysutils import proxy_backend
from pathwaysutils.persistence import orbax_handler
Expand Down Expand Up @@ -94,6 +95,7 @@ def initialize() -> None:
if _is_persistence_enabled():
orbax_handler.register_pathways_handlers(
timeout=datetime.timedelta(hours=1),
array_metadata_store=array_metadata_store_lib.Store(),
)

# Turn off JAX compilation cache because Pathways handles its own
Expand Down
104 changes: 92 additions & 12 deletions pathwaysutils/persistence/orbax_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,18 @@
"""TypeHandlers supporting Pathways backend."""

import collections
from collections.abc import Sequence
from collections.abc import Coroutine, Sequence
import concurrent.futures
import datetime
import functools
import logging
import typing
from typing import Any, cast

import jax
from orbax.checkpoint import future
from orbax.checkpoint import type_handlers
from orbax.checkpoint._src.metadata import array_metadata as array_metadata_lib
from orbax.checkpoint._src.metadata import array_metadata_store as array_metadata_store_lib
from pathwaysutils.persistence import helper


Expand All @@ -33,6 +35,7 @@
# Module-level shorthands for the Orbax types this handler refers to:
# the save/restore argument classes come from `orbax.checkpoint.type_handlers`,
# and `ArrayMetadata` comes from Orbax's (private, `_src`) array-metadata module.
SaveArgs = type_handlers.SaveArgs
RestoreArgs = type_handlers.RestoreArgs
ArrayRestoreArgs = type_handlers.ArrayRestoreArgs
ArrayMetadata = array_metadata_lib.ArrayMetadata


def extract_parent_dir_and_name(
Expand All @@ -51,26 +54,33 @@ def __init__(
self,
timeout: datetime.timedelta | None = None,
use_ocdbt: bool = False,
array_metadata_store: array_metadata_store_lib.Store | None = None,
):
"""Orbax array handler for Pathways on Cloud with Persistence API.

Args:
timeout: Duration indicating the timeout for reading and writing arrays.
Default is 1 hour.
use_ocdbt: allows using Tensorstore OCDBT driver.
array_metadata_store: An optional store for writing and reading array
metadata. Only required for saving new-style jax random keys.
"""
if timeout is None:
timeout = datetime.timedelta(hours=1)
self.timeout = timeout

if use_ocdbt:
raise ValueError("OCDBT not supported for Pathways.")
super().__init__()
super().__init__(array_metadata_store=array_metadata_store)

async def _background_serialize(
self,
futures_results: Sequence[concurrent.futures.Future[None]],
metadata_coroutine: Coroutine[Any, Any, None] | None = None,
) -> None:
if metadata_coroutine:
await metadata_coroutine

for future_result in futures_results:
future_result.result()

Expand All @@ -86,21 +96,60 @@ async def serialize(
values: Sequence[jax.Array],
infos: Sequence[ParamInfo],
args: Sequence[SaveArgs] | None = None,
) -> Sequence[future.Future]:
) -> list[future.Future]:
"""Uses Pathways Persistence API to serialize a jax array."""
type_handlers.check_input_arguments(values, infos, args)

if any([arg.dtype is not None for arg in args]):
raise ValueError("Casting during save not supported for Pathways.")

array_metadatas = []
arrays = []
for v, info, arg in zip(values, infos, args):
ext_metadata = None
if jax.dtypes.issubdtype(v.dtype, jax.dtypes.prng_key):
# a JAX random key
v = jax.random.key_data(v)
ext_metadata = {
array_metadata_lib.RANDOM_KEY_IMPL: str(jax.random.key_impl(v))
}

array_metadatas.append(
ArrayMetadata(
param_name=info.name,
shape=v.shape,
dtype=(arg.dtype if arg is not None else v.dtype),
write_shape=getattr(v, "local_shape", v.shape),
chunk_shape=getattr(v, "local_shape", v.shape),
use_ocdbt=False,
use_zarr3=False,
ext_metadata=ext_metadata,
)
)
arrays.append(v)

metadata_coroutine = None
if array_metadatas:
if self._array_metadata_store is None:
raise ValueError(
"Array metadata store is not set with a checkpoint that requires"
f" it. Array metadata: {array_metadatas}"
)

metadata_coroutine = self._array_metadata_store.write(
checkpoint_dir=infos[0].parent_dir,
array_metadatas=array_metadatas,
process_index=0,
)

self._wait_for_directory_creation_signals()
locations, names = extract_parent_dir_and_name(infos)
f = functools.partial(helper.write_one_array, timeout=self.timeout)
futures_results = list(map(f, locations, names, values))
futures_results = list(map(f, locations, names, arrays))

return [
future.CommitFutureAwaitingContractedSignals(
self._background_serialize(futures_results),
self._background_serialize(futures_results, metadata_coroutine),
name="cloud_pathways_array_handler",
)
]
Expand All @@ -109,7 +158,7 @@ async def deserialize(
self,
infos: Sequence[ParamInfo],
args: Sequence[RestoreArgs] | None = None,
) -> Sequence[jax.Array]:
) -> list[jax.Array]:
"""Uses Pathways Persistence API to deserialize a jax array."""
if args is None:
raise ValueError("Must provide ArrayRestoreArgs to restore as jax.Array.")
Expand All @@ -128,7 +177,7 @@ async def deserialize(
"To restore jax.Array, provide ArrayRestoreArgs; found"
f" {type(arg).__name__}"
)
arg = typing.cast(ArrayRestoreArgs, arg)
arg = cast(ArrayRestoreArgs, arg)
if arg.sharding is None and (arg.mesh is None or arg.mesh_axes is None):
raise ValueError(
"Sharding of jax.Array cannot be None. Provide `mesh`"
Expand All @@ -143,7 +192,7 @@ async def deserialize(
else:
if not isinstance(arg.sharding, jax.sharding.NamedSharding):
raise ValueError("Pathways only supports jax.sharding.NamedSharding.")
sharding = typing.cast(jax.sharding.NamedSharding, arg.sharding)
sharding = cast(jax.sharding.NamedSharding, arg.sharding)
global_meshes.append(sharding.mesh)
mesh_axes.append(sharding.spec)
shardings.append(sharding)
Expand All @@ -163,13 +212,30 @@ async def deserialize(
]
dtypes = [m.dtype if d is None else d for m, d in zip(metadatas, dtypes)]

array_metadatas_cache = {}
if self._array_metadata_store is not None:
if array_metadatas := await self._array_metadata_store.read(
checkpoint_dir=infos[0].parent_dir,
process_index=0,
):
if not isinstance(array_metadatas, list):
raise ValueError(
"Array metadata store returned unexpected result:"
f" {array_metadatas}"
)

array_metadatas_cache = {
array_metadata.param_name: array_metadata
for array_metadata in array_metadatas
}

# Group inputs by global_mesh so that we can perform batched Array
# construction for each global_mesh.
inputs_by_global_mesh = collections.defaultdict(list)
for i, global_mesh in enumerate(global_meshes):
inputs_by_global_mesh[global_mesh].append(i)

results = [None] * len(infos)
results = cast(list[jax.Array], [None] * len(infos))

for global_mesh, idxs in inputs_by_global_mesh.items():
grouped_infos = [infos[idx] for idx in idxs]
Expand All @@ -188,13 +254,26 @@ async def deserialize(
)
# each persistence call is awaited serially.
read_future.result()
for idx, arr in zip(idxs, grouped_arrays):
for idx, info, arr in zip(idxs, grouped_infos, grouped_arrays):
if meta := array_metadatas_cache.get(info.name):
assert isinstance(
meta, array_metadata_lib.SerializedArrayMetadata
), f"Expecting SerializedArrayMetadata but got {type(meta)}."
assert isinstance(meta.ext_metadata, dict), (
"Expecting ext_metadata to be a dict but got"
f" {type(meta.ext_metadata)}."
)

if impl := meta.ext_metadata.get(array_metadata_lib.RANDOM_KEY_IMPL):
arr = jax.random.wrap_key_data(arr, impl=impl)
results[idx] = arr
return results # pytype: disable=bad-return-type

return results


def register_pathways_handlers(
timeout: datetime.timedelta | None = None,
array_metadata_store: array_metadata_store_lib.Store | None = None,
):
"""Function that must be called before saving or restoring with Pathways."""
logger.debug(
Expand All @@ -204,6 +283,7 @@ def register_pathways_handlers(
jax.Array,
CloudPathwaysArrayHandler(
timeout=timeout,
array_metadata_store=array_metadata_store,
),
override=True,
)