From ec8d9828a466bbaec861e9ab7330df7b9219b650 Mon Sep 17 00:00:00 2001
From: lkolluru05
Date: Wed, 2 Jul 2025 05:14:20 +0000
Subject: [PATCH] Add changes to test Orbax checkpointing

---
 axlearn/common/checkpointer.py       | 112 ++++++++++------
 axlearn/common/checkpointer_orbax.py | 185 ++++++++++++++++++++++-----
 axlearn/common/module.py             |   2 +
 axlearn/experiments/text/gpt/fuji.py |  37 ++++--
 4 files changed, 260 insertions(+), 76 deletions(-)

diff --git a/axlearn/common/checkpointer.py b/axlearn/common/checkpointer.py
index 358f43037..7e5c7b11b 100644
--- a/axlearn/common/checkpointer.py
+++ b/axlearn/common/checkpointer.py
@@ -147,51 +147,60 @@ def filter_for_validation(structure):
     )
 
 
-def _upload_dir(src_dir_handle: tempfile.TemporaryDirectory, *, dst_dir: str):
-    """Upload a directory (non-recursively) from a temporary dir to dst_dir.
-
-    Temporary dir will be deleted after the upload is complete.
-    """
-    src_dir = src_dir_handle.name
-    src_files = fs.listdir(src_dir)
-    # src_files will be empty if there are no tf savables (i.e., don't have any tf state to save).
-    # In this case, do not create empty dst_dirs.
-    if len(src_files):
-        fs.makedirs(dst_dir)
-    for item in src_files:
-        src_file = os.path.join(src_dir, item)
-        dst_file = os.path.join(dst_dir, item)
-        assert not fs.isdir(src_file)
-        fs.copy(src_file, dst_file, overwrite=True)
-    src_dir_handle.cleanup()
+# def _upload_dir(src_dir_handle: tempfile.TemporaryDirectory, *, dst_dir: str):
+#     """Upload a directory (non-recursively) from a temporary dir to dst_dir.
+
+#     Temporary dir will be deleted after the upload is complete.
+#     """
+#     src_dir = src_dir_handle.name
+#     src_files = fs.listdir(src_dir)
+#     # src_files will be empty if there are no tf savables (i.e., don't have any tf state to save).
+#     # In this case, do not create empty dst_dirs.
+#     if len(src_files):
+#         fs.makedirs(dst_dir)
+#     for item in src_files:
+#         src_file = os.path.join(src_dir, item)
+#         dst_file = os.path.join(dst_dir, item)
+#         assert not fs.isdir(src_file)
+#         fs.copy(src_file, dst_file, overwrite=True)
+#     src_dir_handle.cleanup()
 
 
 # pylint: disable=redefined-builtin
-def async_save_tf_savables(
-    value_map: Nested[Any], *, executor: futures.ThreadPoolExecutor, dir: str
-) -> futures.Future:
+def async_save_tf_savables(value_map: Nested[Any], *, dir: str) -> list[tf.train.Checkpoint]:
     """Asynchronously saves TF savables from `value_map` into `dir`.
 
-    When this call returns, `value_map` can be safely mutated, but saving to `dir` will not
-    complete unless the returned future is set.
+    When this call returns, each write has been started with `enable_async=True`, but saving
+    to `dir` does not complete until `sync()` is called on each returned checkpoint.
     """
-    # pylint: disable-next=consider-using-with
-    f = tempfile.TemporaryDirectory()
+    tf_ckpts = []
     for path, value in utils.flatten_items(value_map):
         tf_checkpoint = tf.train.Checkpoint(value)
-        tf_checkpoint.write(os.path.join(f.name, path))
-    return executor.submit(_upload_dir, f, dst_dir=dir)
+        tf_checkpoint.write(os.path.join(dir, path), tf.train.CheckpointOptions(enable_async=True))
+        tf_ckpts.append(tf_checkpoint)
+    return tf_ckpts
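For reference, the new save contract follows TF's asynchronous checkpointing API, which this patch uses directly: `write(..., CheckpointOptions(enable_async=True))` returns as soon as the write is enqueued, and `Checkpoint.sync()` blocks until it completes. A minimal standalone sketch (the variable and the /tmp path are illustrative only):

    import tensorflow as tf

    v = tf.Variable(1.0)
    ckpt = tf.train.Checkpoint(v=v)
    # Returns almost immediately; the write continues in the background.
    path = ckpt.write("/tmp/tf_async_demo/ckpt", tf.train.CheckpointOptions(enable_async=True))
    # Blocks until the asynchronous write above has completed.
    ckpt.sync()
    print("checkpoint written to", path)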
""" - # pylint: disable-next=consider-using-with - f = tempfile.TemporaryDirectory() + tf_ckpts = [] for path, value in utils.flatten_items(value_map): tf_checkpoint = tf.train.Checkpoint(value) - tf_checkpoint.write(os.path.join(f.name, path)) - return executor.submit(_upload_dir, f, dst_dir=dir) + tf_checkpoint.write(os.path.join(dir, path), tf.train.CheckpointOptions(enable_async=True)) + tf_ckpts.append(tf_checkpoint) + return tf_ckpts + + +# # pylint: disable-next=redefined-builtin +# def restore_tf_savables(value_map: Nested[Any], *, dir: str) -> Nested[Any]: +# """Restores TF savables from `dir` into `value_map` in-place.""" + +# for path, value in utils.flatten_items(value_map): +# tf_checkpoint = tf.train.Checkpoint(value) +# tf_checkpoint.read(os.path.join(dir, path)) +# return value_map # pylint: disable-next=redefined-builtin -def restore_tf_savables(value_map: Nested[Any], *, dir: str) -> Nested[Any]: +def async_restore_tf_savables(value_map: Nested[Any], *, dir: str) -> list[tf.train.Checkpoint]: """Restores TF savables from `dir` into `value_map` in-place.""" - + ckpts = [] for path, value in utils.flatten_items(value_map): tf_checkpoint = tf.train.Checkpoint(value) - tf_checkpoint.read(os.path.join(dir, path)) + tf_checkpoint.read(os.path.join(dir, path), tf.train.CheckpointOptions(enable_async=True)) + ckpts.append(tf_checkpoint) - return value_map + return ckpts @runtime_checkable @@ -517,11 +526,18 @@ def save_to_dir( # Wait for directory and index creation. multihost_utils.sync_global_devices(ckpt_dir) # Each worker writes its tf checkpoints under a different path. - save_tf_future = async_save_tf_savables( + + # save_tf_future = async_save_tf_savables( + # spec.tf_ckpt_map, + # executor=self._executor, + # dir=os.path.join(ckpt_dir, f"tf_{jax.process_index()}"), + # ) + + tf_ckpts = async_save_tf_savables( spec.tf_ckpt_map, - executor=self._executor, dir=os.path.join(ckpt_dir, f"tf_{jax.process_index()}"), ) + maybe_save_python_savables( spec.python_ckpt_map, dir=os.path.join(ckpt_dir, f"python_{jax.process_index()}") ) @@ -529,8 +545,9 @@ def save_to_dir( def commit(): on_commit_callback(ckpt_dir=ckpt_dir, index=spec.index) logging.info( - "Serialization of %s completed in %s seconds.", - ckpt_dir, + "Checkpointer[%s] saving total time at step %s: %s seconds.", + os.path.basename(os.path.dirname(ckpt_dir)), + step, time.perf_counter() - start_time, ) @@ -542,7 +559,7 @@ def commit(): spec.gda_values, spec.tensorstore_specs, on_commit_callback=commit, - additional_futures=[save_tf_future], + additional_futures=[self._executor.submit(ckpt.sync) for ckpt in tf_ckpts], ) def wait_until_finished(self): @@ -561,13 +578,23 @@ def restore_from_dir( check_state_structure( read_index_file(ckpt_dir), target_structure=spec.index, validation=validation ) - restore_tf_savables( + # restore_tf_savables( + # spec.tf_ckpt_map, dir=os.path.join(ckpt_dir, f"tf_{jax.process_index()}") + # ) + tf_ckpts = async_restore_tf_savables( spec.tf_ckpt_map, dir=os.path.join(ckpt_dir, f"tf_{jax.process_index()}") ) maybe_restore_python_savables( spec.python_ckpt_map, dir=os.path.join(ckpt_dir, f"python_{jax.process_index()}") ) - return self._restore_tensorstore_state(state, ckpt_dir=ckpt_dir, spec=spec) + #return self._restore_tensorstore_state(state, ckpt_dir=ckpt_dir, spec=spec) + + restored_state = self._restore_tensorstore_state(state, ckpt_dir=ckpt_dir, spec=spec) + + for tf_ckpt in tf_ckpts: + tf_ckpt.sync() + + return restored_state def _restore_tensorstore_state( self, state, *, ckpt_dir: 
@@ -561,13 +578,23 @@ def restore_from_dir(
         check_state_structure(
             read_index_file(ckpt_dir), target_structure=spec.index, validation=validation
         )
-        restore_tf_savables(
+        # restore_tf_savables(
+        #     spec.tf_ckpt_map, dir=os.path.join(ckpt_dir, f"tf_{jax.process_index()}")
+        # )
+        tf_ckpts = async_restore_tf_savables(
             spec.tf_ckpt_map, dir=os.path.join(ckpt_dir, f"tf_{jax.process_index()}")
         )
         maybe_restore_python_savables(
             spec.python_ckpt_map, dir=os.path.join(ckpt_dir, f"python_{jax.process_index()}")
         )
-        return self._restore_tensorstore_state(state, ckpt_dir=ckpt_dir, spec=spec)
+        # return self._restore_tensorstore_state(state, ckpt_dir=ckpt_dir, spec=spec)
+
+        # The tensorstore restore runs while the TF restores proceed in the background.
+        restored_state = self._restore_tensorstore_state(state, ckpt_dir=ckpt_dir, spec=spec)
+
+        # Join the asynchronous TF restores before returning.
+        for tf_ckpt in tf_ckpts:
+            tf_ckpt.sync()
+
+        return restored_state
 
     def _restore_tensorstore_state(
         self, state, *, ckpt_dir: str, spec: CheckpointSpec, sync: bool = True
@@ -822,7 +849,8 @@ def __enter__(self):
 
         This is typically invoked prior to the training loop.
         """
         if self._within_context:
-            raise ValueError("Already in a context.")
+            logging.warning("Already in a context.")
+            # raise ValueError("Already in a context.")
         self._within_context = True
 
     def __exit__(
@@ -1042,6 +1070,7 @@ def save(
         In addition to behavior in `BaseCheckpointer`, saving only happens if the configured
         checkpoint policy returns True for the given step and evaler summaries.
         """
+        start_time = time.perf_counter()
         if not self._save_policy(step=step, evaler_summaries=(evaler_summaries or {})):
             return
         if step < 0 or step >= 10**8:
 
             ckpt_dir=ckpt_dir,
             action=CheckpointerAction.SAVE,
         )
+
+        if jax.process_index() == 0:
+            logging.info(
+                "Checkpointer[%s] saving stall time at step %s: %s seconds.",
+                os.path.basename(os.path.dirname(ckpt_dir)),
+                step,
+                time.perf_counter() - start_time,
+            )
 
     def _run_garbage_collection(self):
         """Runs one round of garbage collection of past checkpoints.
@@ -1100,7 +1137,8 @@ def _run_garbage_collection(self):
 
         # For subsequent dirs, non-committed dirs are gc'ed, and committed dirs are kept according
         # to keep_n_steps. Note that we iterate in order of oldest to newest.
-        last_kept_step = float("-inf")
+        # last_kept_step = float("-inf")
+        last_kept_step = 0
        for saved_dir in reversed(dirs[len(remaining_dirs) + len(gc_dirs) :]):
             saved_step = parse_step_from_dir(saved_dir)
             if not (
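Together with the `commit()` change above, the two logs separate how long the train loop was blocked ("stall time", measured inside `save()`) from how long the full checkpoint took ("total time", measured when the async commit fires). The measurement pattern, with stand-in names for the checkpointer and state:

    import time

    start = time.perf_counter()
    checkpointer.save(step=step, state=state)  # Returns once the async save is enqueued.
    stall_secs = time.perf_counter() - start   # Blocking time the train loop actually sees.
    # The background commit callback separately reports total time, which includes
    # serialization that continues after save() has already returned.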
diff --git a/axlearn/common/checkpointer_orbax.py b/axlearn/common/checkpointer_orbax.py
index 2eb205b7f..9ed73b3f0 100644
--- a/axlearn/common/checkpointer_orbax.py
+++ b/axlearn/common/checkpointer_orbax.py
@@ -13,9 +13,12 @@
 from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
 
 import jax
+import numpy as np
 import orbax.checkpoint as ocp
 import tensorflow as tf
 from absl import logging
+from orbax.checkpoint._src.metadata import array_metadata_store as array_metadata_store_lib
+from orbax.checkpoint._src.serialization.type_handlers import ArrayHandler
 
 from axlearn.common import utils
 from axlearn.common.checkpointer import (
@@ -28,11 +31,14 @@
     check_state_structure,
     maybe_restore_python_savables,
     maybe_save_python_savables,
-    restore_tf_savables,
+    # restore_tf_savables,
+    async_restore_tf_savables,
 )
 from axlearn.common.config import config_class
 from axlearn.common.module import Module
 from axlearn.common.utils import Nested, Tensor, TensorSpec
 
 try:
     # The import also registers the checkpoint handlers.
@@ -55,6 +61,10 @@ class _TfIteratorHandler(ocp.type_handlers.TypeHandler):
-    instances), we construct and cleanup the executor per-serialize/deserialize call.
-    """
+    instances), the executor is created once per handler instance in `__init__`.
+    """
+
+    def __init__(self):
+        super().__init__()
+        self._executor = futures.ThreadPoolExecutor()
 
     # Must be a subclass of RestoreArgs for `PyTreeRestore` to recognize it.
     @dataclasses.dataclass
     class RestoreArgs(ocp.type_handlers.RestoreArgs):
@@ -65,7 +75,8 @@ def typestr(self) -> str:
 
     def _ckpt_dir(self, info: ocp.type_handlers.ParamInfo) -> str:
         # Each worker writes its tf checkpoints under a different path.
-        return os.path.join(info.parent_dir, f"tf_{jax.process_index()}")
+        # return os.path.join(info.parent_dir, f"tf_{jax.process_index()}")
+        return os.path.join(os.path.dirname(info.parent_dir), "tfds", f"tf_{jax.process_index()}")
 
     async def serialize(
         self,
         values: Sequence[tf.data.Iterator],
         infos: Sequence[ocp.type_handlers.ParamInfo],
         args: Optional[Sequence[ocp.SaveArgs]] = None,
     ) -> List[futures.Future]:
         """Serializes `values` into corresponding `info.path`s."""
         del args  # Unused.
         futs = []
-        with futures.ThreadPoolExecutor(max_workers=1) as executor:
-            for value, info in zip(values, infos):
-                futs.append(
-                    async_save_tf_savables(
-                        {info.name: value}, executor=executor, dir=self._ckpt_dir(info)
-                    )
-                )
+        # with futures.ThreadPoolExecutor(max_workers=1) as executor:
+        #     for value, info in zip(values, infos):
+        #         futs.append(
+        #             async_save_tf_savables(
+        #                 {info.name: value}, executor=executor, dir=self._ckpt_dir(info)
+        #             )
+        #         )
+        for value, info in zip(values, infos):
+            for tf_ckpt in async_save_tf_savables({info.name: value}, dir=self._ckpt_dir(info)):
+                futs.append(self._executor.submit(tf_ckpt.sync))
         return futs
 
     async def deserialize(
         self,
         infos: Sequence[ocp.type_handlers.ParamInfo],
         args: Optional[Sequence[RestoreArgs]] = None,
     ) -> Sequence[tf.data.Iterator]:
         if args is None:
             raise ValueError(f"{self.RestoreArgs.__name__} should be supplied as args.")
-        futs = []
-        with futures.ThreadPoolExecutor(max_workers=1) as executor:
-            for arg, info in zip(args, infos):
-
-                def restore(arg=arg, info=info):
-                    return restore_tf_savables({info.name: arg.item}, dir=self._ckpt_dir(info))[
-                        info.name
-                    ]
-
-                futs.append(asyncio.get_event_loop().run_in_executor(executor, restore))
-        return await asyncio.gather(*futs)
+        # futs = []
+        # with futures.ThreadPoolExecutor(max_workers=1) as executor:
+        #     for arg, info in zip(args, infos):
+
+        #         def restore(arg=arg, info=info):
+        #             return restore_tf_savables({info.name: arg.item}, dir=self._ckpt_dir(info))[
+        #                 info.name
+        #             ]
+
+        #         futs.append(asyncio.get_event_loop().run_in_executor(executor, restore))
+        # return await asyncio.gather(*futs)
+        iter_ckpts = [
+            ckpt
+            for arg, info in zip(args, infos)
+            for ckpt in async_restore_tf_savables({info.name: arg.item}, dir=self._ckpt_dir(info))
+        ]
+
+        await asyncio.gather(
+            *(
+                asyncio.get_event_loop().run_in_executor(self._executor, ckpt.sync)
+                for ckpt in iter_ckpts
+            )
+        )
+
+        return [arg.item for arg in args]
 
     async def metadata(
         self, infos: Sequence[ocp.type_handlers.ParamInfo]
     ) -> Sequence[ocp.metadata.Metadata]:
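For these handlers to take effect, they must be registered with Orbax's global type-handler registry, which is how the PyTree checkpointing machinery learns to route `tf.data.Iterator` leaves. The registration below mirrors the pattern this module uses (the exact call site is an assumption):

    import orbax.checkpoint as ocp
    import tensorflow as tf

    # Replace any previously registered handler for tf.data.Iterator. Since the
    # handler instance now owns a persistent ThreadPoolExecutor, registering a
    # fresh instance also means a fresh executor.
    ocp.type_handlers.register_type_handler(tf.data.Iterator, _TfIteratorHandler(), override=True)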
@@ -117,6 +147,12 @@
 # TODO(markblee): Generalize to PythonSavableHandler.
 class _GrainDatasetIteratorHandler(ocp.type_handlers.TypeHandler):
     """Serializes grain dataset iterators."""
+
+    def __init__(self):
+        super().__init__()
+        self._executor = futures.ThreadPoolExecutor(
+            max_workers=1, thread_name_prefix="GrainDatasetIteratorHandler"
+        )
 
     @dataclasses.dataclass
     class RestoreArgs(ocp.type_handlers.RestoreArgs):
@@ -127,7 +163,10 @@ def typestr(self) -> str:
 
     def _ckpt_dir(self, info: ocp.type_handlers.ParamInfo) -> str:
         # Each worker writes its grain checkpoints under a different path.
-        return os.path.join(info.parent_dir, f"python_{jax.process_index()}")
+        # return os.path.join(info.parent_dir, f"python_{jax.process_index()}")
+        return os.path.join(
+            os.path.dirname(info.parent_dir), "python", f"python_{jax.process_index()}"
+        )
 
     async def serialize(
         self,
         values: Sequence[_GrainIterator],
         infos: Sequence[ocp.type_handlers.ParamInfo],
         args: Optional[Sequence[ocp.SaveArgs]] = None,
     ) -> List[futures.Future]:
         """Serializes `values` into corresponding `info.path`s."""
         del args  # Unused.
-        for value, info in zip(values, infos):
-            maybe_save_python_savables({info.name: value}, dir=self._ckpt_dir(info))
-        return []
+        # for value, info in zip(values, infos):
+        #     maybe_save_python_savables({info.name: value}, dir=self._ckpt_dir(info))
+        # return []
+        futs = []
+        for value, info in zip(values, infos):
+            futs.append(
+                self._executor.submit(
+                    maybe_save_python_savables, {info.name: value}, dir=self._ckpt_dir(info)
+                )
+            )
+
+        return futs
 
     async def deserialize(
         self,
         infos: Sequence[ocp.type_handlers.ParamInfo],
         args: Optional[Sequence[RestoreArgs]] = None,
     ) -> Sequence[_GrainIterator]:
         if args is None:
             raise ValueError(f"{self.RestoreArgs.__name__} should be supplied as args.")
-        ret = []
-        for arg, info in zip(args, infos):
-            ret.append(
-                maybe_restore_python_savables({info.name: arg.item}, dir=self._ckpt_dir(info))[
-                    info.name
-                ]
-            )
-        return ret
+        # ret = []
+        # for arg, info in zip(args, infos):
+        #     ret.append(
+        #         maybe_restore_python_savables({info.name: arg.item}, dir=self._ckpt_dir(info))[
+        #             info.name
+        #         ]
+        #     )
+        # return ret
+        await asyncio.gather(
+            *(
+                asyncio.get_event_loop().run_in_executor(
+                    self._executor,
+                    functools.partial(
+                        maybe_restore_python_savables,
+                        {info.name: arg.item},
+                        dir=self._ckpt_dir(info),
+                    ),
+                )
+                for arg, info in zip(args, infos)
+            )
+        )
+
+        return [arg.item for arg in args]
 
     async def metadata(
         self, infos: Sequence[ocp.type_handlers.ParamInfo]
     ) -> Sequence[ocp.metadata.Metadata]:
 
@@ -187,15 +250,19 @@ class Config(BaseCheckpointer.Config):
 
         Attributes:
             keep_last_n: Keep this many past ckpts.
+            keep_period: If set, additionally keep a committed checkpoint every
+                `keep_period` steps.
             validation_type: Checkpoint validation during restore.
             async_timeout_secs: Timeout for async barrier in seconds.
+            enable_single_replica_ckpt_restoring: If True, restore arrays on a single model
+                replica and broadcast them to the other replicas.
         """
 
         keep_last_n: int = 1
+        keep_period: Optional[int] = None
         validation_type: CheckpointValidationType = CheckpointValidationType.EXACT
         async_timeout_secs: int = 300
         max_concurrent_save_gb: Optional[int] = None
         max_concurrent_restore_gb: Optional[int] = None
+        enable_single_replica_ckpt_restoring: bool = True
 
     @classmethod
     def checkpoint_paths(cls, base_dir: str) -> List[str]:
@@ -237,6 +304,7 @@ def save_fn_with_summaries(step: int, last_saved_step: Optional[int]) -> bool:
             options=ocp.CheckpointManagerOptions(
                 create=True,
                 max_to_keep=cfg.keep_last_n,
+                keep_period=cfg.keep_period,
                 enable_async_checkpointing=True,
                 step_name_format=self._name_format,
                 should_save_fn=save_fn_with_summaries,
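`keep_period` maps straight onto Orbax's retention options: in addition to the `max_to_keep` newest checkpoints, the manager preserves every checkpoint whose step is a multiple of `keep_period`. A minimal sketch of just that knob (directory path illustrative):

    import orbax.checkpoint as ocp

    options = ocp.CheckpointManagerOptions(max_to_keep=3, keep_period=1000)
    # With these options, steps 1000, 2000, ... survive garbage collection even
    # once they fall out of the 3 most recent checkpoints.
    manager = ocp.CheckpointManager("/tmp/ckpt_demo", options=options)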
@@ -321,11 +389,33 @@ def restore(
 
         cfg: OrbaxCheckpointer.Config = self.config
 
+        if cfg.enable_single_replica_ckpt_restoring:
+            array_handler = ocp.type_handlers.SingleReplicaArrayHandler(
+                replica_axis_index=1,
+                # broadcast_memory_limit_bytes=1024 * 1024 * 1000,  # 1000 MB limit.
+            )
+            ocp.type_handlers.register_type_handler(jax.Array, array_handler, override=True)
+
         def _restore_args(x: Any) -> ocp.RestoreArgs:
             if isinstance(x, (Tensor, TensorSpec)):
-                return ocp.checkpoint_utils.construct_restore_args(
-                    jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding)
-                )
+                if cfg.enable_single_replica_ckpt_restoring:
+                    pspec = x.sharding.spec
+                    mesh = x.sharding.mesh
+                    replica_axis_index = 0
+                    replica_devices = _replica_devices(mesh.devices, replica_axis_index)
+                    replica_mesh = jax.sharding.Mesh(replica_devices, mesh.axis_names)
+                    single_replica_sharding = jax.sharding.NamedSharding(replica_mesh, pspec)
+
+                    return ocp.type_handlers.SingleReplicaArrayRestoreArgs(
+                        sharding=jax.sharding.NamedSharding(mesh, pspec),
+                        single_replica_sharding=single_replica_sharding,
+                        global_shape=x.shape,
+                        dtype=x.dtype,
+                    )
+                else:
+                    return ocp.checkpoint_utils.construct_restore_args(
+                        jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding)
+                    )
             elif isinstance(x, tf.data.Iterator):
                 return _TfIteratorHandler.RestoreArgs(item=x)
             elif _GRAIN_INSTALLED and isinstance(x, _GrainIterator):
@@ -349,6 +439,13 @@ def _restore_args(x: Any) -> ocp.RestoreArgs:
                 raise ValueError(f"Failed to restore at step {step}.") from e
             logging.info("Could not find any completed checkpoints under %s: %s", cfg.dir, e)
             return None, state  # Return the input state.
+        finally:
+            if cfg.enable_single_replica_ckpt_restoring:
+                # Re-register the default array handler so later operations are unaffected.
+                ocp.type_handlers.register_type_handler(
+                    jax.Array,
+                    ArrayHandler(array_metadata_store=array_metadata_store_lib.Store()),
+                    override=True,
+                )
 
         restored_index = composite_state["index"]
         restored_state = composite_state["state"]
@@ -375,3 +472,29 @@ def wait_until_finished(self):
 
     def stop(self, *, has_exception: bool = False):
         """See `BaseCheckpointer.stop` for details."""
         self._manager.close()
+
+
+def _find_idx(array: np.ndarray, replica_axis_idx: int):
+    """Returns the index along the given dimension that the current host belongs to."""
+    idx = None
+    for idx, val in np.ndenumerate(array):
+        if val.process_index == jax.process_index():
+            break
+    return idx[replica_axis_idx]
+
+
+def _replica_devices(device_array: np.ndarray, replica_axis_idx: int):
+    """Returns the devices from the replica that the current host belongs to.
+
+    Replicas are assumed to be restricted to the first axis.
+
+    Args:
+        device_array: Devices of the mesh, as given by `mesh.devices`.
+        replica_axis_idx: The axis along which replicas are laid out.
+
+    Returns:
+        The devices inside the replica that the current host is in, with the
+        replica axis kept as a singleton dimension.
+    """
+    idx = _find_idx(device_array, replica_axis_idx)
+    replica_result = np.take(device_array, idx, axis=replica_axis_idx)
+    return np.expand_dims(replica_result, axis=replica_axis_idx)
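The helper pair implements the device-selection half of single-replica restore: only the replica containing the current host reads the checkpoint from storage, and the arrays are then broadcast to the remaining replicas. A self-contained sketch of the same slicing logic, with `jax.process_index()` replaced by an explicit argument and fake devices standing in for real ones:

    from dataclasses import dataclass

    import numpy as np

    @dataclass
    class FakeDevice:  # Stand-in for a jax device; only process_index matters here.
        process_index: int

    def find_idx(array: np.ndarray, replica_axis_idx: int, process_index: int) -> int:
        # Same walk as `_find_idx`, with the current process passed in explicitly.
        idx = None
        for idx, val in np.ndenumerate(array):
            if val.process_index == process_index:
                break
        return idx[replica_axis_idx]

    # A 2x4 device grid: replicas along axis 0; host 0 owns row 0, host 1 owns row 1.
    devices = np.array([[FakeDevice(p) for _ in range(4)] for p in range(2)])
    row = find_idx(devices, replica_axis_idx=0, process_index=1)
    replica = np.expand_dims(np.take(devices, row, axis=0), axis=0)
    print(row, replica.shape)  # 1 (1, 4) -- one replica's devices, axis kept.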
diff --git a/axlearn/common/module.py b/axlearn/common/module.py
index dacbc62ba..32fd48b68 100644
--- a/axlearn/common/module.py
+++ b/axlearn/common/module.py
@@ -597,6 +597,8 @@ def wrap_method_fn(self, *args, **kwargs):
         try:
             return method_fn_in_context(self, *args, **kwargs)
         except TypeError as e:
+            # Log the full traceback so the failing call is easy to spot in CI.
+            logging.warning("TypeError in wrapped method:\n%s", traceback.format_exc())
             # Make it easier to see what call triggered the error in CI.
             # When running in an environment like TPUs where stack summaries are available,
             # this is unnecessary and we would have slightly cleaner summaries without it.
diff --git a/axlearn/experiments/text/gpt/fuji.py b/axlearn/experiments/text/gpt/fuji.py
index 57c910c70..533e7960b 100644
--- a/axlearn/experiments/text/gpt/fuji.py
+++ b/axlearn/experiments/text/gpt/fuji.py
@@ -366,6 +366,10 @@ def get_trainer_kwargs(
             ),
         )
     elif model_size == "7B":
+        # pylint: disable=import-outside-toplevel
+        import jax
+
+        # Global batch size: one example per device.
+        gbs = len(jax.devices())
         trainer_kwargs = dict(
             model_kwargs=dict(
                 num_layers=32,
             ),
             learner_kwargs=dict(peak_lr=3e-4, weight_decay=0.1),
             max_sequence_length=max_sequence_length,
-            train_batch_size=train_batch_size,
+            train_batch_size=gbs,
             max_step=max_step,
-            mesh_shape=mesh_shape_from_axes(data=-1, fsdp=8),
+            mesh_shape=mesh_shape_from_axes(data=1, fsdp=256),
             mesh_rules=(
                 # Step time:
                 # v1 on tpu-v4-1024 (512 chips): 3.03s
@@ -633,6 +637,9 @@ def get_trainer_kwargs(
             ),
         )
     elif model_size == "70B":
+        # pylint: disable=import-outside-toplevel
+        import jax
+
+        # Global batch size: one example per device.
+        gbs = len(jax.devices())
         trainer_kwargs = dict(
             model_kwargs=dict(
                 num_layers=80,
             ),
             learner_kwargs=dict(peak_lr=1.5e-4, weight_decay=0.1),
             max_sequence_length=max_sequence_length,
-            train_batch_size=train_batch_size,
+            train_batch_size=gbs,  # Was: train_batch_size.
             max_step=max_step,
-            mesh_shape=mesh_shape_from_axes(fsdp=-1),
+            mesh_shape=mesh_shape_from_axes(data=1, fsdp=256),
             mesh_rules=(
                 # TPU V5e maximum per device batch is 1.
                 # with all activation offloading, HBM usage: 14.6GB/chip.
@@ -679,7 +686,7 @@
                 ChainConfigModifier.default_config().set(
                     config_modifiers=[
                         MeshShapeModifier.default_config().set(
-                            mesh_shape=mesh_shape_from_axes(fsdp=-1)
+                            mesh_shape=mesh_shape_from_axes(data=-1, fsdp=64)
                         ),
                         RematSpecModifier.default_config().set(
                             remat_policies={
@@ -914,17 +921,30 @@ def trainer_configs(
     """
     arch = "fuji"
     config_map = {}
-    for version, model_size, flash_attention in itertools.product(
-        Version, MODEL_SIZES, [True, False]
+    for version, model_size, flash_attention, checkpointer in itertools.product(
+        Version,
+        MODEL_SIZES,
+        [True, False],
+        ["", "OrbaxEmergencyCheckpointer", "OrbaxRegularCheckpointer"],
     ):
         if model_size not in TOTAL_TOKENS[version]:  # This combination does not exist.
             continue
         vocab_size = VOCAB_SIZE[version]
+
+        current_suffix_parts = []
+        if flash_attention:
+            current_suffix_parts.append("-flash")
+        if checkpointer == "OrbaxEmergencyCheckpointer":
+            current_suffix_parts.append("-orbaxem")
+        elif checkpointer == "OrbaxRegularCheckpointer":
+            current_suffix_parts.append("-orbax")
+        current_suffix = "".join(current_suffix_parts)
+
         config_name = make_config_name(
             arch=arch,
             model_size=model_size,
             version=f"v{version.value}",
-            suffix="-flash" if flash_attention else "",
+            suffix=current_suffix,
         )
         kwargs = get_trainer_kwargs(
             model_size, vocab_size=vocab_size, version=version, flash_attention=flash_attention
 
             evalers=evaler_config_dict(
                 eval_input_sources(vocab_size=vocab_size, max_sequence_length=max_sequence_length),
             ),
+            checkpointer=checkpointer,
             **kwargs,
         )
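The expanded product makes every (flash, checkpointer) combination addressable by name, e.g. fuji-7B-v2-flash-orbaxem. A quick sketch of the suffix logic in isolation (plain string formatting stands in for make_config_name):

    import itertools

    def config_suffix(flash_attention: bool, checkpointer: str) -> str:
        parts = []
        if flash_attention:
            parts.append("-flash")
        if checkpointer == "OrbaxEmergencyCheckpointer":
            parts.append("-orbaxem")
        elif checkpointer == "OrbaxRegularCheckpointer":
            parts.append("-orbax")
        return "".join(parts)

    checkpointers = ["", "OrbaxEmergencyCheckpointer", "OrbaxRegularCheckpointer"]
    for flash, ckpt in itertools.product([True, False], checkpointers):
        print(f"fuji-7B-v2{config_suffix(flash, ckpt)}")
    # fuji-7B-v2-flash, fuji-7B-v2-flash-orbaxem, fuji-7B-v2-flash-orbax,
    # fuji-7B-v2, fuji-7B-v2-orbaxem, fuji-7B-v2-orbax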