From 6005b6df49bcbd84d2ef3e13cf1c999ce46b8f1d Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Fri, 20 Jun 2025 11:28:24 -0700 Subject: [PATCH 01/57] add test script --- test-orbax.sh | 75 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 75 insertions(+) create mode 100755 test-orbax.sh diff --git a/test-orbax.sh b/test-orbax.sh new file mode 100755 index 000000000..b44c0aaca --- /dev/null +++ b/test-orbax.sh @@ -0,0 +1,75 @@ +#!/usr/bin/env bash + +set -xe + +export NUM_REPLICAS=${NUM_REPLICAS:-2} +export JOBSET_NAME=${JOBSET_NAME:-$USER} +export BASTION_TIER=disabled +export GKE_CLUSTER=$(axlearn gcp config | grep gke_cluster | awk '{ print $3 }' | tr -d '"') +# Switch to tpu-v6e-256 if on scale cluster +export INSTANCE_TYPE=${INSTANCE_TYPE:-"tpu-v6e-16"} +# Switch to tpu-v6e-256-4 if on scale cluster +export MESH_SELECTOR=${MESH:-"tpu-v6e-16"} +export CONFIG=${CONFIG:-"fuji-7B-v3-flash-orbaxem"} +export PROJECT_ID=$(gcloud config get project) + +# Example for v6e-256 +# MESH_SELECTOR=tpu-v6e-256-4 INSTANCE_TYPE=tpu-v6e-256 ./test-orbax.sh + +# The bundle step is needed if you run on cloudtop +# uncomment if you use cloudtop +axlearn gcp bundle --name=$JOBSET_NAME \ + --bundler_spec=allow_dirty=True \ + --bundler_type=artifactregistry \ + --bundler_spec=dockerfile=Dockerfile \ + --bundler_spec=image=tpu \ + --bundler_spec=target=tpu + +# Only enable kueue when running on scale testing cluster +# --queue=multislice-queue \ +# --priority_class=very-high \ +# --trainer_dir=gs://tess-checkpoints-us-west1/${JOBSET_NAME}-nr-${NUM_REPLICAS}/ \ + +# Check if CONFIG ends with "orbaxem" +if [[ "$CONFIG" == *"orbaxem"* ]]; then + echo "Running with Orbax emergency checkpointer." + axlearn gcp launch run --cluster=$GKE_CLUSTER \ + --runner_name gke_tpu_single \ + --name=$JOBSET_NAME \ + --instance_type=${INSTANCE_TYPE} \ + --host_mount_spec=name=tmp,host_path=/tmp,mount_path=/host-tmp \ + --num_replicas=${NUM_REPLICAS} \ + --bundler_spec=allow_dirty=True \ + --bundler_type=artifactregistry --bundler_spec=image=tpu \ + --bundler_spec=dockerfile=Dockerfile --bundler_spec=target=tpu \ + -- "ulimit -n 1048576; ulimit -c 0; python3 -c 'import jax; jax.devices()'; python3 -m axlearn.common.launch_trainer_main" \ + --init_module=axlearn.common.checkpointer_orbax_emergency:local_ckpt_dir=/host-tmp/checkpoints \ + --module=text.gpt.c4_trainer \ + --config=${CONFIG} \ + --trainer_dir=gs://${PROJECT_ID}-axlearn/${JOBSET_NAME}-nr-${NUM_REPLICAS}/ \ + --data_dir=gs://axlearn-public/tensorflow_datasets \ + --jax_backend=tpu \ + --mesh_selector=${MESH_SELECTOR} \ + --initialization_timeout=1200 \ + --trace_at_steps=29,59,89,119,149,179,209,239,269,299,329,359,389,419,449,479,509,539,569,599,629,659,689,719 + +else + echo "Running without Orbax emergency checkpointer." 
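+    # Same launch as the orbaxem branch above, minus the emergency-checkpointer
+    # pieces (--host_mount_spec, --init_module and the ulimit bumps), since
+    # non-orbaxem configs keep no node-local checkpoint state.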
+ axlearn gcp launch run --cluster=$GKE_CLUSTER \ + --runner_name gke_tpu_single \ + --name=$JOBSET_NAME \ + --instance_type=${INSTANCE_TYPE} \ + --num_replicas=${NUM_REPLICAS} \ + --bundler_spec=allow_dirty=True \ + --bundler_type=artifactregistry --bundler_spec=image=tpu \ + --bundler_spec=dockerfile=Dockerfile --bundler_spec=target=tpu \ + -- "python3 -c 'import jax; jax.devices()'; python3 -m axlearn.common.launch_trainer_main" \ + --module=text.gpt.c4_trainer \ + --config=${CONFIG} \ + --trainer_dir=gs://${PROJECT_ID}-axlearn/${JOBSET_NAME}-nr-${NUM_REPLICAS}/ \ + --data_dir=gs://axlearn-public/tensorflow_datasets \ + --jax_backend=tpu \ + --mesh_selector=${MESH_SELECTOR} \ + --initialization_timeout=1200 \ + --trace_at_steps=29,59,89,119,149,179,209,239,269,299,329,359,389,419,449,479,509,539,569,599,629,659,689,719 +fi From 40d1ed5143a1b5b5995f0c65f9c957cdad8540a3 Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Fri, 20 Jun 2025 11:48:23 -0700 Subject: [PATCH 02/57] Jun Orbax regular checkpointer fixes --- axlearn/common/checkpointer_orbax.py | 63 ++++++++++++++++++++++++++-- 1 file changed, 60 insertions(+), 3 deletions(-) diff --git a/axlearn/common/checkpointer_orbax.py b/axlearn/common/checkpointer_orbax.py index 2eb205b7f..2c2ee2108 100644 --- a/axlearn/common/checkpointer_orbax.py +++ b/axlearn/common/checkpointer_orbax.py @@ -13,9 +13,12 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union import jax +import numpy as np import orbax.checkpoint as ocp import tensorflow as tf from absl import logging +from orbax.checkpoint._src.metadata import array_metadata_store as array_metadata_store_lib +from orbax.checkpoint._src.serialization.type_handlers import ArrayHandler from axlearn.common import utils from axlearn.common.checkpointer import ( @@ -196,6 +199,7 @@ class Config(BaseCheckpointer.Config): async_timeout_secs: int = 300 max_concurrent_save_gb: Optional[int] = None max_concurrent_restore_gb: Optional[int] = None + enable_single_replica_ckpt_restoring: bool = True @classmethod def checkpoint_paths(cls, base_dir: str) -> List[str]: @@ -321,11 +325,33 @@ def restore( cfg: OrbaxCheckpointer.Config = self.config + if cfg.enable_single_replica_ckpt_restoring: + array_handler = ocp.type_handlers.SingleReplicaArrayHandler( + replica_axis_index=0, + broadcast_memory_limit_bytes=1024 * 1024 * 1000, # 1000 MB limit + ) + ocp.type_handlers.register_type_handler(jax.Array, array_handler, override=True) + def _restore_args(x: Any) -> ocp.RestoreArgs: if isinstance(x, (Tensor, TensorSpec)): - return ocp.checkpoint_utils.construct_restore_args( - jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding) - ) + if cfg.enable_single_replica_ckpt_restoring: + pspec = x.sharding.spec + mesh = x.sharding.mesh + replica_axis_index = 0 + replica_devices = _replica_devices(mesh.devices, replica_axis_index) + replica_mesh = jax.sharding.Mesh(replica_devices, mesh.axis_names) + single_replica_sharding = jax.sharding.NamedSharding(replica_mesh, pspec) + + return ocp.type_handlers.SingleReplicaArrayRestoreArgs( + sharding=jax.sharding.NamedSharding(mesh, pspec), + single_replica_sharding=single_replica_sharding, + global_shape=x.shape, + dtype=x.dtype, + ) + else: + return ocp.checkpoint_utils.construct_restore_args( + jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding) + ) elif isinstance(x, tf.data.Iterator): return _TfIteratorHandler.RestoreArgs(item=x) elif _GRAIN_INSTALLED and isinstance(x, _GrainIterator): @@ -349,6 +375,11 @@ def 
_restore_args(x: Any) -> ocp.RestoreArgs: raise ValueError(f"Failed to restore at step {step}.") from e logging.info("Could not find any completed checkpoints under %s: %s", cfg.dir, e) return None, state # Return the input state. + finally: + if cfg.enable_single_replica_ckpt_restoring: + ocp.type_handlers.register_type_handler( + jax.Array, ArrayHandler(array_metadata_store=array_metadata_store_lib.Store()) + ) restored_index = composite_state["index"] restored_state = composite_state["state"] @@ -375,3 +406,29 @@ def wait_until_finished(self): def stop(self, *, has_exception: bool = False): """See `BaseCheckpointer.stop` for details.""" self._manager.close() + + +def _find_idx(array: np.ndarray, replica_axis_idx: int): + """Returns the index along given dimension that the current host belongs to.""" + idx = None + for idx, val in np.ndenumerate(array): + if val.process_index == jax.process_index(): + break + return idx[replica_axis_idx] + + +def _replica_devices(device_array: np.ndarray, replica_axis_idx: int): + """Returns the devices from the replica that current host belongs to. + + Replicas are assumed to be restricted to the first axis. + + Args: + device_array: devices of the mesh that can be obtained by mesh.devices() + replica_axis_idx: axis dimension along which replica is taken + + Returns: + devices inside the replica that current host is in + """ + idx = _find_idx(device_array, replica_axis_idx) + replica_result = np.take(device_array, idx, axis=replica_axis_idx) + return np.expand_dims(replica_result, axis=replica_axis_idx) From b9194c1be64f4bebc4f7e45b21657741cc8cb8fe Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Thu, 22 May 2025 13:16:24 -0700 Subject: [PATCH 03/57] Orbax emergency trainer config for Fuji --- axlearn/experiments/text/gpt/common.py | 39 ++++++++++++++++++++++---- axlearn/experiments/text/gpt/fuji.py | 16 +++++++++-- 2 files changed, 46 insertions(+), 9 deletions(-) diff --git a/axlearn/experiments/text/gpt/common.py b/axlearn/experiments/text/gpt/common.py index 57d606dab..0239a6b8d 100644 --- a/axlearn/experiments/text/gpt/common.py +++ b/axlearn/experiments/text/gpt/common.py @@ -643,6 +643,7 @@ def get_trainer_config_fn( keep_every_n_steps: int = 50_000, save_every_n_steps: Optional[int] = None, init_state_builder: Optional[state_builder.Builder.Config] = None, + checkpointer: str = "", ) -> TrainerConfigFn: """Builds a TrainerConfigFn according to the model and input specs. @@ -709,13 +710,39 @@ def config_fn() -> InstantiableConfig: ) ) cfg.evalers[name] = evaler_cfg + # Summaries and checkpoints. - cfg.checkpointer.save_policy = config_for_function(every_n_steps_and_last_policy).set( - n=save_every_n_steps or min(eval_every_n_steps, 5_000), - max_step=max_step, - ) - cfg.checkpointer.keep_every_n_steps = min(max_step, keep_every_n_steps) - cfg.checkpointer.keep_last_n = 3 + calculated_save_every_n_steps = save_every_n_steps or min(eval_every_n_steps, 5_000) + + if not checkpointer: + cfg.checkpointer.save_policy = config_for_function(every_n_steps_and_last_policy).set( + n=calculated_save_every_n_steps, + max_step=max_step, + ) + cfg.checkpointer.keep_every_n_steps = min(max_step, keep_every_n_steps) + cfg.checkpointer.keep_last_n = 3 + elif checkpointer == "OrbaxEmergencyCheckpointer": + # Prevent global dependency on Orbax. 
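+            # (The import runs only when this checkpointer is selected, so plain
+            # AXLearn configs never require orbax to be installed.)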
+ # pylint: disable-next=import-outside-toplevel + from axlearn.common.checkpointer_orbax_emergency import OrbaxEmergencyCheckpointer + + ckpt_config: OrbaxEmergencyCheckpointer.Config = ( + OrbaxEmergencyCheckpointer.default_config() + ) + ckpt_config.save_policy = config_for_function(every_n_steps_and_last_policy).set( + n=calculated_save_every_n_steps, + max_step=max_step, + ) + ckpt_config.local_save_policy = config_for_function(every_n_steps_and_last_policy).set( + n=calculated_save_every_n_steps, + max_step=max_step, + ) + ckpt_config.local_dir = "/host-tmp/checkpoints" + ckpt_config.keep_every_n_steps = min(max_step, keep_every_n_steps) + ckpt_config.keep_last_n = 3 + ckpt_config.replica_axis_index = 1 + cfg.checkpointer = ckpt_config + cfg.summary_writer.write_every_n_steps = min(eval_every_n_steps, 100) cfg.summary_writer.max_queue = 1000 if len(mesh_axis_names) != len(mesh_shape): diff --git a/axlearn/experiments/text/gpt/fuji.py b/axlearn/experiments/text/gpt/fuji.py index 57c910c70..9d61dcfa7 100644 --- a/axlearn/experiments/text/gpt/fuji.py +++ b/axlearn/experiments/text/gpt/fuji.py @@ -914,22 +914,31 @@ def trainer_configs( """ arch = "fuji" config_map = {} - for version, model_size, flash_attention in itertools.product( - Version, MODEL_SIZES, [True, False] + for version, model_size, flash_attention, use_orbax_emergency_ckpt in itertools.product( + Version, MODEL_SIZES, [True, False], [False, True] ): if model_size not in TOTAL_TOKENS[version]: # This combination does not exist. continue vocab_size = VOCAB_SIZE[version] + + current_suffix_parts = [] + if flash_attention: + current_suffix_parts.append("-flash") + if use_orbax_emergency_ckpt: + current_suffix_parts.append("-orbaxem") + current_suffix = "".join(current_suffix_parts) + config_name = make_config_name( arch=arch, model_size=model_size, version=f"v{version.value}", - suffix="-flash" if flash_attention else "", + suffix=current_suffix, ) kwargs = get_trainer_kwargs( model_size, vocab_size=vocab_size, version=version, flash_attention=flash_attention ) max_sequence_length = kwargs.pop("max_sequence_length") + checkpointer_str = "OrbaxEmergencyCheckpointer" if use_orbax_emergency_ckpt else "" # pylint: disable-next=unexpected-keyword-arg,missing-kwoa config_map[config_name] = get_trainer_config_fn( train_input_source=train_input_source( @@ -939,6 +948,7 @@ def trainer_configs( evalers=evaler_config_dict( eval_input_sources(vocab_size=vocab_size, max_sequence_length=max_sequence_length), ), + checkpointer=checkpointer_str, **kwargs, ) From 74f7522f31ab3c27e85c926f2c91c76df8890119 Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Fri, 20 Jun 2025 12:32:28 -0700 Subject: [PATCH 04/57] update fuji configs to use regular orbax checkpointer --- axlearn/experiments/text/gpt/common.py | 13 +++++++++++++ axlearn/experiments/text/gpt/fuji.py | 14 +++++++++----- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/axlearn/experiments/text/gpt/common.py b/axlearn/experiments/text/gpt/common.py index 0239a6b8d..90e08d58c 100644 --- a/axlearn/experiments/text/gpt/common.py +++ b/axlearn/experiments/text/gpt/common.py @@ -742,6 +742,19 @@ def config_fn() -> InstantiableConfig: ckpt_config.keep_last_n = 3 ckpt_config.replica_axis_index = 1 cfg.checkpointer = ckpt_config + elif checkpointer == "OrbaxRegularCheckpointer": + # Prevent global dependency on Orbax. 
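+            # (As above, import lazily so orbax stays an optional dependency.)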
+ # pylint: disable-next=import-outside-toplevel + from axlearn.common.checkpointer_orbax import OrbaxCheckpointer + + ckpt_config: OrbaxCheckpointer.Config = OrbaxCheckpointer.default_config() + ckpt_config.save_policy = config_for_function(every_n_steps_and_last_policy).set( + n=calculated_save_every_n_steps, + max_step=max_step, + ) + ckpt_config.keep_every_n_steps = min(max_step, keep_every_n_steps) + ckpt_config.keep_last_n = 3 + cfg.checkpointer = ckpt_config cfg.summary_writer.write_every_n_steps = min(eval_every_n_steps, 100) cfg.summary_writer.max_queue = 1000 diff --git a/axlearn/experiments/text/gpt/fuji.py b/axlearn/experiments/text/gpt/fuji.py index 9d61dcfa7..07eebddd8 100644 --- a/axlearn/experiments/text/gpt/fuji.py +++ b/axlearn/experiments/text/gpt/fuji.py @@ -914,8 +914,11 @@ def trainer_configs( """ arch = "fuji" config_map = {} - for version, model_size, flash_attention, use_orbax_emergency_ckpt in itertools.product( - Version, MODEL_SIZES, [True, False], [False, True] + for version, model_size, flash_attention, checkpointer in itertools.product( + Version, + MODEL_SIZES, + [True, False], + ["", "OrbaxEmergencyCheckpointer", "OrbaxRegularCheckpointer"], ): if model_size not in TOTAL_TOKENS[version]: # This combination does not exist. continue @@ -924,8 +927,10 @@ def trainer_configs( current_suffix_parts = [] if flash_attention: current_suffix_parts.append("-flash") - if use_orbax_emergency_ckpt: + if checkpointer == "OrbaxEmergencyCheckpointer": current_suffix_parts.append("-orbaxem") + elif checkpointer == "OrbaxRegularCheckpointer": + current_suffix_parts.append("-orbax") current_suffix = "".join(current_suffix_parts) config_name = make_config_name( @@ -938,7 +943,6 @@ def trainer_configs( model_size, vocab_size=vocab_size, version=version, flash_attention=flash_attention ) max_sequence_length = kwargs.pop("max_sequence_length") - checkpointer_str = "OrbaxEmergencyCheckpointer" if use_orbax_emergency_ckpt else "" # pylint: disable-next=unexpected-keyword-arg,missing-kwoa config_map[config_name] = get_trainer_config_fn( train_input_source=train_input_source( @@ -948,7 +952,7 @@ def trainer_configs( evalers=evaler_config_dict( eval_input_sources(vocab_size=vocab_size, max_sequence_length=max_sequence_length), ), - checkpointer=checkpointer_str, + checkpointer=checkpointer, **kwargs, ) From 8e8093120db1f85e299763da6e4d73af2f21b35a Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Fri, 20 Jun 2025 13:36:37 -0700 Subject: [PATCH 05/57] support keep_period --- axlearn/common/checkpointer_orbax.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/axlearn/common/checkpointer_orbax.py b/axlearn/common/checkpointer_orbax.py index 2c2ee2108..5aa8f72a2 100644 --- a/axlearn/common/checkpointer_orbax.py +++ b/axlearn/common/checkpointer_orbax.py @@ -190,11 +190,13 @@ class Config(BaseCheckpointer.Config): Attributes: keep_last_n: Keep this many past ckpts. + keep_every_n_steps: If set, keep a checkpoint every n steps. validation_type: Checkpoint validation during restore. async_timeout_secs: Timeout for async barrier in seconds. 
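+        max_concurrent_save_gb: If set, caps concurrent save traffic, in GB.
+        max_concurrent_restore_gb: If set, caps concurrent restore traffic, in GB.
+        enable_single_replica_ckpt_restoring: If True, one replica restores from
+            storage and broadcasts to the others instead of every replica reading.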
""" keep_last_n: int = 1 + keep_every_n_steps: Optional[int] = None validation_type: CheckpointValidationType = CheckpointValidationType.EXACT async_timeout_secs: int = 300 max_concurrent_save_gb: Optional[int] = None @@ -241,6 +243,7 @@ def save_fn_with_summaries(step: int, last_saved_step: Optional[int]) -> bool: options=ocp.CheckpointManagerOptions( create=True, max_to_keep=cfg.keep_last_n, + keep_period=cfg.keep_every_n_steps, enable_async_checkpointing=True, step_name_format=self._name_format, should_save_fn=save_fn_with_summaries, From 10d05765ee8ef309478789b4517ec06369426fda Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Fri, 20 Jun 2025 13:37:12 -0700 Subject: [PATCH 06/57] support run for orbax regular checkpointer --- Dockerfile | 2 +- test-orbax.sh | 19 ++++++++++--------- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/Dockerfile b/Dockerfile index 29db664d3..dc011dcb0 100644 --- a/Dockerfile +++ b/Dockerfile @@ -83,7 +83,7 @@ ENTRYPOINT ["/opt/apache/beam/boot"] FROM base AS tpu -ARG EXTRAS= +ARG EXTRAS=orbax ENV UV_FIND_LINKS=https://storage.googleapis.com/jax-releases/libtpu_releases.html # Ensure we install the TPU version, even if building locally. diff --git a/test-orbax.sh b/test-orbax.sh index b44c0aaca..b44e5ba49 100755 --- a/test-orbax.sh +++ b/test-orbax.sh @@ -10,7 +10,7 @@ export GKE_CLUSTER=$(axlearn gcp config | grep gke_cluster | awk '{ print $3 }' export INSTANCE_TYPE=${INSTANCE_TYPE:-"tpu-v6e-16"} # Switch to tpu-v6e-256-4 if on scale cluster export MESH_SELECTOR=${MESH:-"tpu-v6e-16"} -export CONFIG=${CONFIG:-"fuji-7B-v3-flash-orbaxem"} +export CONFIG=${CONFIG:-"fuji-7B-v3-flash-orbax"} export PROJECT_ID=$(gcloud config get project) # Example for v6e-256 @@ -18,17 +18,18 @@ export PROJECT_ID=$(gcloud config get project) # The bundle step is needed if you run on cloudtop # uncomment if you use cloudtop -axlearn gcp bundle --name=$JOBSET_NAME \ - --bundler_spec=allow_dirty=True \ - --bundler_type=artifactregistry \ - --bundler_spec=dockerfile=Dockerfile \ - --bundler_spec=image=tpu \ - --bundler_spec=target=tpu +# axlearn gcp bundle --name=$JOBSET_NAME \ +# --bundler_spec=allow_dirty=True \ +# --bundler_type=artifactregistry \ +# --bundler_spec=dockerfile=Dockerfile \ +# --bundler_spec=image=tpu \ +# --bundler_spec=target=tpu # Only enable kueue when running on scale testing cluster # --queue=multislice-queue \ # --priority_class=very-high \ # --trainer_dir=gs://tess-checkpoints-us-west1/${JOBSET_NAME}-nr-${NUM_REPLICAS}/ \ +# # Check if CONFIG ends with "orbaxem" if [[ "$CONFIG" == *"orbaxem"* ]]; then @@ -54,7 +55,7 @@ if [[ "$CONFIG" == *"orbaxem"* ]]; then --trace_at_steps=29,59,89,119,149,179,209,239,269,299,329,359,389,419,449,479,509,539,569,599,629,659,689,719 else - echo "Running without Orbax emergency checkpointer." + echo "Running Orbax regular checkpointer or AXLearn native." 
axlearn gcp launch run --cluster=$GKE_CLUSTER \ --runner_name gke_tpu_single \ --name=$JOBSET_NAME \ @@ -63,7 +64,7 @@ else --bundler_spec=allow_dirty=True \ --bundler_type=artifactregistry --bundler_spec=image=tpu \ --bundler_spec=dockerfile=Dockerfile --bundler_spec=target=tpu \ - -- "python3 -c 'import jax; jax.devices()'; python3 -m axlearn.common.launch_trainer_main" \ + -- "ulimit -n 1048576; ulimit -c 0; python3 -c 'import jax; jax.devices()'; python3 -m axlearn.common.launch_trainer_main" \ --module=text.gpt.c4_trainer \ --config=${CONFIG} \ --trainer_dir=gs://${PROJECT_ID}-axlearn/${JOBSET_NAME}-nr-${NUM_REPLICAS}/ \ From abbe0acd3cd449d0c78ec84743f8b0f9abcac9c7 Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Fri, 20 Jun 2025 13:56:30 -0700 Subject: [PATCH 07/57] fix for A TypeHandler for "" is already registered. https://buganizer.corp.google.com/issues/419599840#comment18 --- axlearn/common/checkpointer_orbax.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/axlearn/common/checkpointer_orbax.py b/axlearn/common/checkpointer_orbax.py index 5aa8f72a2..c96c48d17 100644 --- a/axlearn/common/checkpointer_orbax.py +++ b/axlearn/common/checkpointer_orbax.py @@ -381,7 +381,9 @@ def _restore_args(x: Any) -> ocp.RestoreArgs: finally: if cfg.enable_single_replica_ckpt_restoring: ocp.type_handlers.register_type_handler( - jax.Array, ArrayHandler(array_metadata_store=array_metadata_store_lib.Store()) + jax.Array, + ArrayHandler(array_metadata_store=array_metadata_store_lib.Store()), + override=True, ) restored_index = composite_state["index"] From 129281995dd1c6398f147f6ce2bcda7f5a053890 Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Fri, 20 Jun 2025 14:27:37 -0700 Subject: [PATCH 08/57] pdbs=1 and print every step --- axlearn/common/trainer.py | 2 +- axlearn/experiments/text/gpt/fuji.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/axlearn/common/trainer.py b/axlearn/common/trainer.py index 0603f7bf9..4d501f64a 100644 --- a/axlearn/common/trainer.py +++ b/axlearn/common/trainer.py @@ -613,7 +613,7 @@ def run( ) self.vlog(3, "Done step %s", self.step) num_steps += 1 - if num_steps % 100 == 0: + if num_steps % 1 == 0: now = time.perf_counter() average_step_time = (now - start_time) / num_steps self._step_log("Average step time: %s seconds", average_step_time) diff --git a/axlearn/experiments/text/gpt/fuji.py b/axlearn/experiments/text/gpt/fuji.py index 07eebddd8..c4e0eac14 100644 --- a/axlearn/experiments/text/gpt/fuji.py +++ b/axlearn/experiments/text/gpt/fuji.py @@ -366,6 +366,10 @@ def get_trainer_kwargs( ), ) elif model_size == "7B": + # pylint: disable=import-outside-toplevel + import jax + + gbs = len(jax.devices()) trainer_kwargs = dict( model_kwargs=dict( num_layers=32, @@ -378,7 +382,7 @@ def get_trainer_kwargs( ), learner_kwargs=dict(peak_lr=3e-4, weight_decay=0.1), max_sequence_length=max_sequence_length, - train_batch_size=train_batch_size, + train_batch_size=gbs, max_step=max_step, mesh_shape=mesh_shape_from_axes(data=-1, fsdp=8), mesh_rules=( From 073cbfafa20aa393dd81a6e5403ea043e4a23eed Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Fri, 20 Jun 2025 14:42:28 -0700 Subject: [PATCH 09/57] checkpoint every 100 steps --- axlearn/experiments/text/gpt/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/axlearn/experiments/text/gpt/common.py b/axlearn/experiments/text/gpt/common.py index 90e08d58c..015a59b88 100644 --- a/axlearn/experiments/text/gpt/common.py +++ 
b/axlearn/experiments/text/gpt/common.py @@ -712,7 +712,7 @@ def config_fn() -> InstantiableConfig: cfg.evalers[name] = evaler_cfg # Summaries and checkpoints. - calculated_save_every_n_steps = save_every_n_steps or min(eval_every_n_steps, 5_000) + calculated_save_every_n_steps = save_every_n_steps or min(eval_every_n_steps, 100) if not checkpointer: cfg.checkpointer.save_policy = config_for_function(every_n_steps_and_last_policy).set( From d723cc734ea9642db1bb8b6957adcdc21a2b5d9f Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Fri, 20 Jun 2025 15:34:39 -0700 Subject: [PATCH 10/57] increase termination Grace Period to 300s --- axlearn/cloud/gcp/jobset_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/axlearn/cloud/gcp/jobset_utils.py b/axlearn/cloud/gcp/jobset_utils.py index ab3a7daaf..736656ab6 100644 --- a/axlearn/cloud/gcp/jobset_utils.py +++ b/axlearn/cloud/gcp/jobset_utils.py @@ -690,7 +690,7 @@ def _build_pod(self) -> Nested[Any]: spec = dict( # NOTE: Don't set hostNetwork or dnsPolicy for compat with Workload Identity. - terminationGracePeriodSeconds=60, + terminationGracePeriodSeconds=300, # Fail if any pod fails, and allow retries to happen at JobSet level. restartPolicy="Never", # https://kubernetes.io/docs/tasks/network/customize-hosts-file-for-pods/#adding-additional-entries-with-hostaliases From 383d0bc8d21e8aca5e94fdf8cb707e83f29a1f6c Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Fri, 20 Jun 2025 15:45:58 -0700 Subject: [PATCH 11/57] termination grace period to 900s --- axlearn/cloud/gcp/jobset_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/axlearn/cloud/gcp/jobset_utils.py b/axlearn/cloud/gcp/jobset_utils.py index 736656ab6..83c66a916 100644 --- a/axlearn/cloud/gcp/jobset_utils.py +++ b/axlearn/cloud/gcp/jobset_utils.py @@ -690,7 +690,7 @@ def _build_pod(self) -> Nested[Any]: spec = dict( # NOTE: Don't set hostNetwork or dnsPolicy for compat with Workload Identity. - terminationGracePeriodSeconds=300, + terminationGracePeriodSeconds=900, # Fail if any pod fails, and allow retries to happen at JobSet level. restartPolicy="Never", # https://kubernetes.io/docs/tasks/network/customize-hosts-file-for-pods/#adding-additional-entries-with-hostaliases From 90f1e2739abad81218437190892a344b34feecf7 Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Fri, 20 Jun 2025 16:25:56 -0700 Subject: [PATCH 12/57] Revert "termination grace period to 900s" This reverts commit 383d0bc8d21e8aca5e94fdf8cb707e83f29a1f6c. --- axlearn/cloud/gcp/jobset_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/axlearn/cloud/gcp/jobset_utils.py b/axlearn/cloud/gcp/jobset_utils.py index 83c66a916..736656ab6 100644 --- a/axlearn/cloud/gcp/jobset_utils.py +++ b/axlearn/cloud/gcp/jobset_utils.py @@ -690,7 +690,7 @@ def _build_pod(self) -> Nested[Any]: spec = dict( # NOTE: Don't set hostNetwork or dnsPolicy for compat with Workload Identity. - terminationGracePeriodSeconds=900, + terminationGracePeriodSeconds=300, # Fail if any pod fails, and allow retries to happen at JobSet level. 
restartPolicy="Never", # https://kubernetes.io/docs/tasks/network/customize-hosts-file-for-pods/#adding-additional-entries-with-hostaliases From 871d2dbf190c35a7042c3d876f67acd36944e6eb Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Wed, 25 Jun 2025 08:19:55 -0700 Subject: [PATCH 13/57] save the data iterator --- axlearn/experiments/text/gpt/common.py | 3 +++ test-orbax.sh | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/axlearn/experiments/text/gpt/common.py b/axlearn/experiments/text/gpt/common.py index 015a59b88..799c6ed25 100644 --- a/axlearn/experiments/text/gpt/common.py +++ b/axlearn/experiments/text/gpt/common.py @@ -756,6 +756,9 @@ def config_fn() -> InstantiableConfig: ckpt_config.keep_last_n = 3 cfg.checkpointer = ckpt_config + # Save the data iterator as part of the checkpointing process. + cfg.save_input_iterator = True + cfg.summary_writer.write_every_n_steps = min(eval_every_n_steps, 100) cfg.summary_writer.max_queue = 1000 if len(mesh_axis_names) != len(mesh_shape): diff --git a/test-orbax.sh b/test-orbax.sh index b44e5ba49..2cd91feb1 100755 --- a/test-orbax.sh +++ b/test-orbax.sh @@ -10,7 +10,7 @@ export GKE_CLUSTER=$(axlearn gcp config | grep gke_cluster | awk '{ print $3 }' export INSTANCE_TYPE=${INSTANCE_TYPE:-"tpu-v6e-16"} # Switch to tpu-v6e-256-4 if on scale cluster export MESH_SELECTOR=${MESH:-"tpu-v6e-16"} -export CONFIG=${CONFIG:-"fuji-7B-v3-flash-orbax"} +export CONFIG=${CONFIG:-"fuji-7B-v3-tiktoken-flash-orbax"} export PROJECT_ID=$(gcloud config get project) # Example for v6e-256 From bd25cff13af2512bac19e13195b608a5a63625ce Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Mon, 21 Jul 2025 10:47:35 -0700 Subject: [PATCH 14/57] use fuji v3 8b for tiktoken --- test-orbax.sh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test-orbax.sh b/test-orbax.sh index 2cd91feb1..d4072c5f9 100755 --- a/test-orbax.sh +++ b/test-orbax.sh @@ -10,7 +10,8 @@ export GKE_CLUSTER=$(axlearn gcp config | grep gke_cluster | awk '{ print $3 }' export INSTANCE_TYPE=${INSTANCE_TYPE:-"tpu-v6e-16"} # Switch to tpu-v6e-256-4 if on scale cluster export MESH_SELECTOR=${MESH:-"tpu-v6e-16"} -export CONFIG=${CONFIG:-"fuji-7B-v3-tiktoken-flash-orbax"} +# Need to use tiktoken when saving data iterator +export CONFIG=${CONFIG:-"fuji-8B-v3-tiktoken-flash-orbax"} export PROJECT_ID=$(gcloud config get project) # Example for v6e-256 From 8923ca60b712b85bc045600684c91de4f541cece Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Mon, 28 Jul 2025 14:31:03 -0700 Subject: [PATCH 15/57] use tokenizers instead of tokenizer --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 57e415e7e..100976d54 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,7 @@ core = [ "protobuf>=3.20.3", "tensorboard-plugin-profile==2.15.1", # This has both x86 and arm64 wheels. Underneath the hood it uses tensorflow-macos since 2.13. + "tokenizers", "tensorflow==2.17.1", "tensorflow-datasets>=4.9.2", "tensorflow-io>=0.37.1", # for tensorflow-2.16. Note that 0.37.0 results in "pure virtual method called". 
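Background for the input-iterator toggling in patches 13/57 and 16/57: checkpointing
a tf.data iterator records its position, so a restored run resumes mid-stream instead
of replaying or skipping examples. A minimal standalone sketch of the underlying
mechanism with plain TensorFlow (tf.train.Checkpoint; not AXLearn's _TfIteratorHandler):

    import tensorflow as tf

    ds = tf.data.Dataset.range(10)
    it = iter(ds)
    print(next(it).numpy())             # 0

    ckpt = tf.train.Checkpoint(it=it)
    path = ckpt.save("/tmp/iter_ckpt")  # persists the iterator's position
    print(next(it).numpy())             # 1

    ckpt.restore(path)                  # rewinds to the saved position
    print(next(it).numpy())             # 1 again
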
From 03c3fb6bf1699abaa275571e10a3718ddc72b3d3 Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Mon, 28 Jul 2025 16:19:02 -0700 Subject: [PATCH 16/57] disable saving of data iterator --- axlearn/experiments/text/gpt/common.py | 3 ++- test-orbax.sh | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/axlearn/experiments/text/gpt/common.py b/axlearn/experiments/text/gpt/common.py index 799c6ed25..4ebc58f6d 100644 --- a/axlearn/experiments/text/gpt/common.py +++ b/axlearn/experiments/text/gpt/common.py @@ -757,7 +757,8 @@ def config_fn() -> InstantiableConfig: cfg.checkpointer = ckpt_config # Save the data iterator as part of the checkpointing process. - cfg.save_input_iterator = True + # default is false. + # cfg.save_input_iterator = True cfg.summary_writer.write_every_n_steps = min(eval_every_n_steps, 100) cfg.summary_writer.max_queue = 1000 diff --git a/test-orbax.sh b/test-orbax.sh index d4072c5f9..69081613c 100755 --- a/test-orbax.sh +++ b/test-orbax.sh @@ -11,7 +11,8 @@ export INSTANCE_TYPE=${INSTANCE_TYPE:-"tpu-v6e-16"} # Switch to tpu-v6e-256-4 if on scale cluster export MESH_SELECTOR=${MESH:-"tpu-v6e-16"} # Need to use tiktoken when saving data iterator -export CONFIG=${CONFIG:-"fuji-8B-v3-tiktoken-flash-orbax"} +# export CONFIG=${CONFIG:-"fuji-8B-v3-tiktoken-flash-orbax"} +export CONFIG=${CONFIG:-"fuji-7B-v3-flash-orbax"} export PROJECT_ID=$(gcloud config get project) # Example for v6e-256 From 410134dbaecc409016180cb28fbdb76da4b92a7b Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Mon, 28 Jul 2025 19:29:03 -0700 Subject: [PATCH 17/57] use orbaxem --- test-orbax.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test-orbax.sh b/test-orbax.sh index 69081613c..1d5700a5e 100755 --- a/test-orbax.sh +++ b/test-orbax.sh @@ -12,7 +12,7 @@ export INSTANCE_TYPE=${INSTANCE_TYPE:-"tpu-v6e-16"} export MESH_SELECTOR=${MESH:-"tpu-v6e-16"} # Need to use tiktoken when saving data iterator # export CONFIG=${CONFIG:-"fuji-8B-v3-tiktoken-flash-orbax"} -export CONFIG=${CONFIG:-"fuji-7B-v3-flash-orbax"} +export CONFIG=${CONFIG:-"fuji-7B-v3-flash-orbaxem"} export PROJECT_ID=$(gcloud config get project) # Example for v6e-256 From 9051f903eddc0ed6d9d615c1ec99dc0e26c75025 Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Mon, 28 Jul 2025 19:42:57 -0700 Subject: [PATCH 18/57] add needed fix for orbax em bigger buffers and newer orbax --- axlearn/cloud/gcp/jobset_utils.py | 3 +++ axlearn/common/compiler_options.py | 7 ++++++- axlearn/experiments/text/gpt/common.py | 2 +- pyproject.toml | 2 +- 4 files changed, 11 insertions(+), 3 deletions(-) diff --git a/axlearn/cloud/gcp/jobset_utils.py b/axlearn/cloud/gcp/jobset_utils.py index 736656ab6..6a435fd79 100644 --- a/axlearn/cloud/gcp/jobset_utils.py +++ b/axlearn/cloud/gcp/jobset_utils.py @@ -451,6 +451,9 @@ def _build_container(self) -> Nested[Any]: if cfg.enable_tpu_ici_resiliency is not None: env_vars["ENABLE_ICI_RESILIENCY"] = str(cfg.enable_tpu_ici_resiliency).lower() + env_vars["TPU_PREMAPPED_BUFFER_SIZE"] = "137438953472" + env_vars["TPU_PREMAPPED_BUFFER_TRANSFER_THRESHOLD_BYTES"] = "137438953472" + resources = {"limits": {"google.com/tpu": system.chips_per_vm}} # Set request memory by host machine type. 
machine_memory_gi = GCE_MACHINE_TYPE_TO_MEMORY_CHARACTERISTICS.get( diff --git a/axlearn/common/compiler_options.py b/axlearn/common/compiler_options.py index 06239e54b..793989144 100644 --- a/axlearn/common/compiler_options.py +++ b/axlearn/common/compiler_options.py @@ -58,7 +58,12 @@ def default_xla_options( # cause the step time to be double. You should increase this # further if you see "Allocator failed to allocate". A feature # to dynamically allocate may come later: b/380514965 - megascale_grpc_premap_memory_bytes=17179869184, + # Needed for orbax emergency checkpointer + # Needed for decent perf + megascale_grpc_premap_memory_bytes=137438953472, + # needed for restore consistent hash + megascale_jax_offset_launch_id_by_module_name=True, + megascale_jax_use_device_set_based_launch_id=False, # Flag controlling the maximum number of overlapping host offloadings. xla_tpu_host_transfer_overlap_limit=24, # Flag controlling the maximum number of overlapping cross-DCN send/recv. diff --git a/axlearn/experiments/text/gpt/common.py b/axlearn/experiments/text/gpt/common.py index 4ebc58f6d..d24622a50 100644 --- a/axlearn/experiments/text/gpt/common.py +++ b/axlearn/experiments/text/gpt/common.py @@ -734,7 +734,7 @@ def config_fn() -> InstantiableConfig: max_step=max_step, ) ckpt_config.local_save_policy = config_for_function(every_n_steps_and_last_policy).set( - n=calculated_save_every_n_steps, + n=20, max_step=max_step, ) ckpt_config.local_dir = "/host-tmp/checkpoints" diff --git a/pyproject.toml b/pyproject.toml index 100976d54..83ae9dbf7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -155,7 +155,7 @@ mmau = [ # Orbax checkpointing. orbax = [ "humanize==4.10.0", - "orbax-checkpoint==0.11.15", + "orbax-checkpoint==0.11.20", ] # Audio dependencies. 
audio = [ From b10878fc0d35307e22f1322122f0cfb9d1e1d702 Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Tue, 29 Jul 2025 09:50:44 -0700 Subject: [PATCH 19/57] enable orbax debug logging --- axlearn/experiments/text/gpt/common.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/axlearn/experiments/text/gpt/common.py b/axlearn/experiments/text/gpt/common.py index d24622a50..34ea4262e 100644 --- a/axlearn/experiments/text/gpt/common.py +++ b/axlearn/experiments/text/gpt/common.py @@ -16,6 +16,9 @@ import jax.numpy as jnp import tensorflow as tf + +# Needed for enabling Orbax debug logging +from absl import logging from jax.sharding import PartitionSpec from axlearn.common import ( @@ -79,6 +82,9 @@ MESH_AXIS_NAMES = ("pipeline", "data", "expert", "fsdp", "seq", "model") +logging.set_verbosity(logging.DEBUG) + + def scaled_hidden_dim(scale: float, *, round_up_to_multiples_of: int = 256) -> FunctionConfigBase: def scale_fn(input_dim: int, *, scale: float, round_up_to_multiples_of: int) -> int: return math.ceil(input_dim * scale / round_up_to_multiples_of) * round_up_to_multiples_of From bc3cbfc3fd84586fd4888e03ff5742044aa9ba36 Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Tue, 29 Jul 2025 20:30:03 -0700 Subject: [PATCH 20/57] sort the to be assigned keys and available process indexes --- axlearn/common/checkpointer_orbax_emergency.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/axlearn/common/checkpointer_orbax_emergency.py b/axlearn/common/checkpointer_orbax_emergency.py index f868488e0..a83372a91 100644 --- a/axlearn/common/checkpointer_orbax_emergency.py +++ b/axlearn/common/checkpointer_orbax_emergency.py @@ -345,9 +345,12 @@ def first_run_assign_fn(info: _ProcessInfo): # If there're no assigned slice ids, that means all slices have failed or we're in the # very first run. In that case, first_run_assign_fn will be used. if already_assigned_slice_ids: - to_be_assigned_slice_ids = set(range(num_slices)) - already_assigned_slice_ids - assert len(to_be_assigned_slice_ids) == len(failed_slices_new_ids) - for k, new_id in zip(failed_slices_new_ids.keys(), to_be_assigned_slice_ids): + to_be_assigned_slice_ids = sorted( + list(set(range(num_slices)) - already_assigned_slice_ids) + ) + failed_slice_keys = sorted(list(failed_slices_new_ids.keys())) + assert len(to_be_assigned_slice_ids) == len(failed_slice_keys) + for k, new_id in zip(failed_slice_keys, to_be_assigned_slice_ids): failed_slices_new_ids[k] = new_id def assign_fn(info: _ProcessInfo): From f8000d7ea17681b02c490b01420ec4d52b9c0c30 Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Tue, 29 Jul 2025 21:49:47 -0700 Subject: [PATCH 21/57] sort proc_infos as well --- axlearn/common/checkpointer_orbax_emergency.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/axlearn/common/checkpointer_orbax_emergency.py b/axlearn/common/checkpointer_orbax_emergency.py index a83372a91..e96e1f48b 100644 --- a/axlearn/common/checkpointer_orbax_emergency.py +++ b/axlearn/common/checkpointer_orbax_emergency.py @@ -334,6 +334,10 @@ def first_run_assign_fn(info: _ProcessInfo): for k, data in ids: info = _ProcessInfo.from_string(data, key=k, num_proc_per_slice=num_proc_per_slice) proc_infos.append(info) + # Sort by current proc id to ensure the order is deterministic. 
+ proc_infos.sort(key=lambda info: info.cur_proc_id) + + for info in proc_infos: if info.inv_proc_id == -1: failed_slices_new_ids[info.cur_slice_id] = -1 @@ -372,6 +376,10 @@ def assign_fn(info: _ProcessInfo): for key, data in ids: info = _ProcessInfo.from_string(data, key=key) proc_infos.append(info) + # Sort by current proc id to ensure the order is deterministic. + proc_infos.sort(key=lambda info: info.cur_proc_id) + + for info in proc_infos: assigned_ids.add(info.inv_proc_id) # If there're no assigned ids, that means all slices have failed or we're in the From e752a128176d56d3b6bacdaa87aee3f18a05bcff Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Tue, 29 Jul 2025 23:47:10 -0700 Subject: [PATCH 22/57] Revert "sort proc_infos as well" This reverts commit f8000d7ea17681b02c490b01420ec4d52b9c0c30. --- axlearn/common/checkpointer_orbax_emergency.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/axlearn/common/checkpointer_orbax_emergency.py b/axlearn/common/checkpointer_orbax_emergency.py index e96e1f48b..a83372a91 100644 --- a/axlearn/common/checkpointer_orbax_emergency.py +++ b/axlearn/common/checkpointer_orbax_emergency.py @@ -334,10 +334,6 @@ def first_run_assign_fn(info: _ProcessInfo): for k, data in ids: info = _ProcessInfo.from_string(data, key=k, num_proc_per_slice=num_proc_per_slice) proc_infos.append(info) - # Sort by current proc id to ensure the order is deterministic. - proc_infos.sort(key=lambda info: info.cur_proc_id) - - for info in proc_infos: if info.inv_proc_id == -1: failed_slices_new_ids[info.cur_slice_id] = -1 @@ -376,10 +372,6 @@ def assign_fn(info: _ProcessInfo): for key, data in ids: info = _ProcessInfo.from_string(data, key=key) proc_infos.append(info) - # Sort by current proc id to ensure the order is deterministic. - proc_infos.sort(key=lambda info: info.cur_proc_id) - - for info in proc_infos: assigned_ids.add(info.inv_proc_id) # If there're no assigned ids, that means all slices have failed or we're in the From c7bdd39284cbd7d84b09681d928aa240f6ac69bd Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Tue, 29 Jul 2025 23:47:17 -0700 Subject: [PATCH 23/57] Revert "sort the to be assigned keys and available process indexes" This reverts commit bc3cbfc3fd84586fd4888e03ff5742044aa9ba36. --- axlearn/common/checkpointer_orbax_emergency.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/axlearn/common/checkpointer_orbax_emergency.py b/axlearn/common/checkpointer_orbax_emergency.py index a83372a91..f868488e0 100644 --- a/axlearn/common/checkpointer_orbax_emergency.py +++ b/axlearn/common/checkpointer_orbax_emergency.py @@ -345,12 +345,9 @@ def first_run_assign_fn(info: _ProcessInfo): # If there're no assigned slice ids, that means all slices have failed or we're in the # very first run. In that case, first_run_assign_fn will be used. 
if already_assigned_slice_ids: - to_be_assigned_slice_ids = sorted( - list(set(range(num_slices)) - already_assigned_slice_ids) - ) - failed_slice_keys = sorted(list(failed_slices_new_ids.keys())) - assert len(to_be_assigned_slice_ids) == len(failed_slice_keys) - for k, new_id in zip(failed_slice_keys, to_be_assigned_slice_ids): + to_be_assigned_slice_ids = set(range(num_slices)) - already_assigned_slice_ids + assert len(to_be_assigned_slice_ids) == len(failed_slices_new_ids) + for k, new_id in zip(failed_slices_new_ids.keys(), to_be_assigned_slice_ids): failed_slices_new_ids[k] = new_id def assign_fn(info: _ProcessInfo): From 13cb8ac8dd65bad03fd4a2722a2843d9bc9e6f93 Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Wed, 30 Jul 2025 00:04:54 -0700 Subject: [PATCH 24/57] gemini fix? --- .../common/checkpointer_orbax_emergency.py | 33 ++++++++++++------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/axlearn/common/checkpointer_orbax_emergency.py b/axlearn/common/checkpointer_orbax_emergency.py index f868488e0..a1373c90c 100644 --- a/axlearn/common/checkpointer_orbax_emergency.py +++ b/axlearn/common/checkpointer_orbax_emergency.py @@ -345,19 +345,30 @@ def first_run_assign_fn(info: _ProcessInfo): # If there're no assigned slice ids, that means all slices have failed or we're in the # very first run. In that case, first_run_assign_fn will be used. if already_assigned_slice_ids: - to_be_assigned_slice_ids = set(range(num_slices)) - already_assigned_slice_ids - assert len(to_be_assigned_slice_ids) == len(failed_slices_new_ids) - for k, new_id in zip(failed_slices_new_ids.keys(), to_be_assigned_slice_ids): - failed_slices_new_ids[k] = new_id + # Create a flat list of all available inv_proc_ids from the dead slices. + to_be_assigned_inv_proc_ids = [] + all_slice_ids = set(range(num_slices)) + dead_slice_ids = all_slice_ids - already_assigned_slice_ids + for slice_id in sorted(list(dead_slice_ids)): + for i in range(num_proc_per_slice): + to_be_assigned_inv_proc_ids.append(slice_id * num_proc_per_slice + i) + + # Create a flat list of all newcomer processes that need an ID. + newcomer_procs = [p for p in proc_infos if p.inv_proc_id == -1] + # Sort newcomers by their temporary physical ID to ensure determinism. + newcomer_procs.sort(key=lambda p: p.cur_proc_id) + + assert len(to_be_assigned_inv_proc_ids) == len(newcomer_procs) + + # Assign the sorted available IDs to the sorted newcomers directly. + for proc_info, inv_proc_id in zip(newcomer_procs, to_be_assigned_inv_proc_ids): + proc_info.inv_proc_id = inv_proc_id def assign_fn(info: _ProcessInfo): - proc_id = info.inv_proc_id - if (new_slice_id := failed_slices_new_ids.get(info.cur_slice_id)) is not None: - proc_id = ( - new_slice_id * num_proc_per_slice - + info.cur_proc_id % num_proc_per_slice - ) - info.inv_proc_id = proc_id + # This function is now only responsible for applying the pre-computed + # assignments to the processes. 
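+            # (The actual inv_proc_id mutation now happens above, where the
+            # sorted newcomer_procs are zipped with the sorted free ids; the
+            # print below is debug output only.)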
+ print(f"processInfo={info}") + pass inv_id_assign_fn = assign_fn From ce3d7ad20e0ae3078c76574c36f3c48beae467f6 Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Wed, 30 Jul 2025 21:25:26 +0800 Subject: [PATCH 25/57] fail fast --- axlearn/common/trainer.py | 66 +++++++++++++++++++-------------------- 1 file changed, 32 insertions(+), 34 deletions(-) diff --git a/axlearn/common/trainer.py b/axlearn/common/trainer.py index 4d501f64a..1c6c710aa 100644 --- a/axlearn/common/trainer.py +++ b/axlearn/common/trainer.py @@ -927,42 +927,40 @@ def restore_checkpoint(self, restore_step: Optional[int] = None) -> Optional[int **ckpt_state_spec, input_iter=iter(self.input.dataset()) ) restore_input_iter = cfg.save_input_iterator - try: - # Try to restore with `input_iter`. - step, ckpt_state = self.checkpointer.restore( - step=restore_step, - state=( - ckpt_state_spec_with_input_iter if restore_input_iter else ckpt_state_spec - ), - ) - if step is not None: - self.vlog( - 0, - "Restored checkpoint at %s with restore_input_iter=%s", - step, - restore_input_iter, - ) - except ValueError as e: - logging.warning( - "Attempt to restore checkpoint with restore_input_iter=%s failed: %s", + # try: + # Try to restore with `input_iter`. + step, ckpt_state = self.checkpointer.restore( + step=restore_step, + state=(ckpt_state_spec_with_input_iter if restore_input_iter else ckpt_state_spec), + ) + if step is not None: + self.vlog( + 0, + "Restored checkpoint at %s with restore_input_iter=%s", + step, restore_input_iter, - e, - ) - # Restore with a different restore_input_iter setting. - restore_input_iter = not restore_input_iter - step, ckpt_state = self.checkpointer.restore( - step=restore_step, - state=( - ckpt_state_spec_with_input_iter if restore_input_iter else ckpt_state_spec - ), ) - if step is not None: - self.vlog( - 0, - "Restored checkpoint at %s with restore_input_iter=%s", - step, - restore_input_iter, - ) + # except ValueError as e: + # logging.warning( + # "Attempt to restore checkpoint with restore_input_iter=%s failed: %s", + # restore_input_iter, + # e, + # ) + # # Restore with a different restore_input_iter setting. + # restore_input_iter = not restore_input_iter + # step, ckpt_state = self.checkpointer.restore( + # step=restore_step, + # state=( + # ckpt_state_spec_with_input_iter if restore_input_iter else ckpt_state_spec + # ), + # ) + # if step is not None: + # self.vlog( + # 0, + # "Restored checkpoint at %s with restore_input_iter=%s", + # step, + # restore_input_iter, + # ) if step is not None: self._step = step self._trainer_state = TrainerState( From 6223a3f4e0af66128957fd796c37318e333597ba Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Wed, 30 Jul 2025 09:34:21 -0700 Subject: [PATCH 26/57] Revert "gemini fix?" This reverts commit 13cb8ac8dd65bad03fd4a2722a2843d9bc9e6f93. --- .../common/checkpointer_orbax_emergency.py | 33 +++++++------------ 1 file changed, 11 insertions(+), 22 deletions(-) diff --git a/axlearn/common/checkpointer_orbax_emergency.py b/axlearn/common/checkpointer_orbax_emergency.py index a1373c90c..f868488e0 100644 --- a/axlearn/common/checkpointer_orbax_emergency.py +++ b/axlearn/common/checkpointer_orbax_emergency.py @@ -345,30 +345,19 @@ def first_run_assign_fn(info: _ProcessInfo): # If there're no assigned slice ids, that means all slices have failed or we're in the # very first run. In that case, first_run_assign_fn will be used. if already_assigned_slice_ids: - # Create a flat list of all available inv_proc_ids from the dead slices. 
- to_be_assigned_inv_proc_ids = [] - all_slice_ids = set(range(num_slices)) - dead_slice_ids = all_slice_ids - already_assigned_slice_ids - for slice_id in sorted(list(dead_slice_ids)): - for i in range(num_proc_per_slice): - to_be_assigned_inv_proc_ids.append(slice_id * num_proc_per_slice + i) - - # Create a flat list of all newcomer processes that need an ID. - newcomer_procs = [p for p in proc_infos if p.inv_proc_id == -1] - # Sort newcomers by their temporary physical ID to ensure determinism. - newcomer_procs.sort(key=lambda p: p.cur_proc_id) - - assert len(to_be_assigned_inv_proc_ids) == len(newcomer_procs) - - # Assign the sorted available IDs to the sorted newcomers directly. - for proc_info, inv_proc_id in zip(newcomer_procs, to_be_assigned_inv_proc_ids): - proc_info.inv_proc_id = inv_proc_id + to_be_assigned_slice_ids = set(range(num_slices)) - already_assigned_slice_ids + assert len(to_be_assigned_slice_ids) == len(failed_slices_new_ids) + for k, new_id in zip(failed_slices_new_ids.keys(), to_be_assigned_slice_ids): + failed_slices_new_ids[k] = new_id def assign_fn(info: _ProcessInfo): - # This function is now only responsible for applying the pre-computed - # assignments to the processes. - print(f"processInfo={info}") - pass + proc_id = info.inv_proc_id + if (new_slice_id := failed_slices_new_ids.get(info.cur_slice_id)) is not None: + proc_id = ( + new_slice_id * num_proc_per_slice + + info.cur_proc_id % num_proc_per_slice + ) + info.inv_proc_id = proc_id inv_id_assign_fn = assign_fn From 1fcb62bfabc20f603977be0eab9af04fa91a0e94 Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Wed, 30 Jul 2025 14:02:37 -0700 Subject: [PATCH 27/57] switch orbax with Jun's patch --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 83ae9dbf7..0a9bce9ca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -155,7 +155,8 @@ mmau = [ # Orbax checkpointing. orbax = [ "humanize==4.10.0", - "orbax-checkpoint==0.11.20", + # "orbax-checkpoint==0.11.20", + "orbax-checkpoint @ git+https://github.com/samos123/orbax.git@v0.11.20-em#subdirectory=checkpoint" ] # Audio dependencies. 
audio = [ From ad4a4986300dd6aac315bf7f8e90f88671743b1b Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Wed, 30 Jul 2025 18:22:35 -0700 Subject: [PATCH 28/57] add large scale config --- axlearn/experiments/text/gpt/fuji.py | 8 ++++++-- test-orbax.sh | 4 +++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/axlearn/experiments/text/gpt/fuji.py b/axlearn/experiments/text/gpt/fuji.py index c4e0eac14..40fb010b7 100644 --- a/axlearn/experiments/text/gpt/fuji.py +++ b/axlearn/experiments/text/gpt/fuji.py @@ -637,6 +637,10 @@ def get_trainer_kwargs( ), ) elif model_size == "70B": + # pylint: disable=import-outside-toplevel + import jax + + gbs = len(jax.devices()) trainer_kwargs = dict( model_kwargs=dict( num_layers=80, @@ -652,7 +656,7 @@ def get_trainer_kwargs( ), learner_kwargs=dict(peak_lr=1.5e-4, weight_decay=0.1), max_sequence_length=max_sequence_length, - train_batch_size=train_batch_size, + train_batch_size=gbs, max_step=max_step, mesh_shape=mesh_shape_from_axes(fsdp=-1), mesh_rules=( @@ -714,7 +718,7 @@ def get_trainer_kwargs( ChainConfigModifier.default_config().set( config_modifiers=[ MeshShapeModifier.default_config().set( - mesh_shape=mesh_shape_from_axes(data=-1, fsdp=256) + mesh_shape=mesh_shape_from_axes(data=-1, fsdp=128) ), RematSpecModifier.default_config().set( remat_policies={ diff --git a/test-orbax.sh b/test-orbax.sh index 1d5700a5e..aa7eea91b 100755 --- a/test-orbax.sh +++ b/test-orbax.sh @@ -9,7 +9,7 @@ export GKE_CLUSTER=$(axlearn gcp config | grep gke_cluster | awk '{ print $3 }' # Switch to tpu-v6e-256 if on scale cluster export INSTANCE_TYPE=${INSTANCE_TYPE:-"tpu-v6e-16"} # Switch to tpu-v6e-256-4 if on scale cluster -export MESH_SELECTOR=${MESH:-"tpu-v6e-16"} +export MESH_SELECTOR=${MESH_SELECTOR:-"tpu-v6e-16"} # Need to use tiktoken when saving data iterator # export CONFIG=${CONFIG:-"fuji-8B-v3-tiktoken-flash-orbax"} export CONFIG=${CONFIG:-"fuji-7B-v3-flash-orbaxem"} @@ -38,6 +38,8 @@ if [[ "$CONFIG" == *"orbaxem"* ]]; then echo "Running with Orbax emergency checkpointer." 
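+    # Kueue admission below (--queue/--priority_class) is only for runs on the
+    # scale-testing cluster, per the note near the top of this script.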
axlearn gcp launch run --cluster=$GKE_CLUSTER \ --runner_name gke_tpu_single \ + --queue=multislice-queue \ + --priority_class=very-high \ --name=$JOBSET_NAME \ --instance_type=${INSTANCE_TYPE} \ --host_mount_spec=name=tmp,host_path=/tmp,mount_path=/host-tmp \ From 3be33cd398d98e3c69bf0df107cc0cb236580a6f Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Wed, 30 Jul 2025 23:05:09 -0700 Subject: [PATCH 29/57] enable BlockingRecreate --- axlearn/cloud/gcp/job.py | 4 +++- test-orbax.sh | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/axlearn/cloud/gcp/job.py b/axlearn/cloud/gcp/job.py index 92cb8d045..e1a1c6d41 100644 --- a/axlearn/cloud/gcp/job.py +++ b/axlearn/cloud/gcp/job.py @@ -134,7 +134,9 @@ def _build_jobset(self) -> Nested[Any]: return dict( metadata=dict(name=cfg.name, annotations=annotations), spec=dict( - failurePolicy=dict(maxRestarts=cfg.max_tries - 1), + failurePolicy=dict( + maxRestarts=cfg.max_tries - 1, restartStrategy="BlockingRecreate" + ), replicatedJobs=self._builder(), ), ) diff --git a/test-orbax.sh b/test-orbax.sh index aa7eea91b..f96d2c1dc 100755 --- a/test-orbax.sh +++ b/test-orbax.sh @@ -51,7 +51,7 @@ if [[ "$CONFIG" == *"orbaxem"* ]]; then --init_module=axlearn.common.checkpointer_orbax_emergency:local_ckpt_dir=/host-tmp/checkpoints \ --module=text.gpt.c4_trainer \ --config=${CONFIG} \ - --trainer_dir=gs://${PROJECT_ID}-axlearn/${JOBSET_NAME}-nr-${NUM_REPLICAS}/ \ + --trainer_dir=gs://${PROJECT_ID}-axlearn/${JOBSET_NAME} \ --data_dir=gs://axlearn-public/tensorflow_datasets \ --jax_backend=tpu \ --mesh_selector=${MESH_SELECTOR} \ From c4f20f216d8b9572d04bc656861a2f047d97f2d5 Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Thu, 31 Jul 2025 10:50:48 -0700 Subject: [PATCH 30/57] disable BlockingRecreate on jobset --- axlearn/cloud/gcp/job.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/axlearn/cloud/gcp/job.py b/axlearn/cloud/gcp/job.py index e1a1c6d41..e7b8e56f5 100644 --- a/axlearn/cloud/gcp/job.py +++ b/axlearn/cloud/gcp/job.py @@ -135,7 +135,9 @@ def _build_jobset(self) -> Nested[Any]: metadata=dict(name=cfg.name, annotations=annotations), spec=dict( failurePolicy=dict( - maxRestarts=cfg.max_tries - 1, restartStrategy="BlockingRecreate" + # maxRestarts=cfg.max_tries - 1, restartStrategy="BlockingRecreate" + maxRestarts=cfg.max_tries + - 1, ), replicatedJobs=self._builder(), ), From 0d9513c927f7bd86a013cc1da9f01f818123d067 Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Sun, 3 Aug 2025 10:57:51 -0700 Subject: [PATCH 31/57] switch cluster --- test-orbax.sh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test-orbax.sh b/test-orbax.sh index f96d2c1dc..e6b7867b8 100755 --- a/test-orbax.sh +++ b/test-orbax.sh @@ -14,6 +14,7 @@ export MESH_SELECTOR=${MESH_SELECTOR:-"tpu-v6e-16"} # export CONFIG=${CONFIG:-"fuji-8B-v3-tiktoken-flash-orbax"} export CONFIG=${CONFIG:-"fuji-7B-v3-flash-orbaxem"} export PROJECT_ID=$(gcloud config get project) +export TRAINER_DIR=gs://tpu-prod-env-multipod-use4 # Example for v6e-256 # MESH_SELECTOR=tpu-v6e-256-4 INSTANCE_TYPE=tpu-v6e-256 ./test-orbax.sh @@ -51,7 +52,7 @@ if [[ "$CONFIG" == *"orbaxem"* ]]; then --init_module=axlearn.common.checkpointer_orbax_emergency:local_ckpt_dir=/host-tmp/checkpoints \ --module=text.gpt.c4_trainer \ --config=${CONFIG} \ - --trainer_dir=gs://${PROJECT_ID}-axlearn/${JOBSET_NAME} \ + --trainer_dir=${TRAINER_DIR}/${JOBSET_NAME} \ --data_dir=gs://axlearn-public/tensorflow_datasets \ --jax_backend=tpu \ --mesh_selector=${MESH_SELECTOR} \ From 
c4f87662b8d438451c0afbe700dc226441dbe122 Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Sun, 3 Aug 2025 11:45:02 -0700 Subject: [PATCH 32/57] remove debug logging --- axlearn/experiments/text/gpt/common.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/axlearn/experiments/text/gpt/common.py b/axlearn/experiments/text/gpt/common.py index 34ea4262e..d24622a50 100644 --- a/axlearn/experiments/text/gpt/common.py +++ b/axlearn/experiments/text/gpt/common.py @@ -16,9 +16,6 @@ import jax.numpy as jnp import tensorflow as tf - -# Needed for enabling Orbax debug logging -from absl import logging from jax.sharding import PartitionSpec from axlearn.common import ( @@ -82,9 +79,6 @@ MESH_AXIS_NAMES = ("pipeline", "data", "expert", "fsdp", "seq", "model") -logging.set_verbosity(logging.DEBUG) - - def scaled_hidden_dim(scale: float, *, round_up_to_multiples_of: int = 256) -> FunctionConfigBase: def scale_fn(input_dim: int, *, scale: float, round_up_to_multiples_of: int) -> int: return math.ceil(input_dim * scale / round_up_to_multiples_of) * round_up_to_multiples_of From 087028535add3e2d5c7a4014d04b252e8d4e7845 Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Sun, 3 Aug 2025 15:46:48 -0700 Subject: [PATCH 33/57] add script to force delete pods --- force_delete_terminating_pods.sh | 59 ++++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) create mode 100755 force_delete_terminating_pods.sh diff --git a/force_delete_terminating_pods.sh b/force_delete_terminating_pods.sh new file mode 100755 index 000000000..dee311d8e --- /dev/null +++ b/force_delete_terminating_pods.sh @@ -0,0 +1,59 @@ +#!/bin/bash +# +# This script finds pods that are stuck in the 'Terminating' state for more +# than a specified duration and forcibly deletes them. + +# --- Configuration --- +# The maximum duration (in seconds) a pod is allowed to be in the Terminating state. +STUCK_DURATION_SECONDS=1200 + +# How often (in seconds) the script should check for stuck pods. +CHECK_INTERVAL_SECONDS=60 +# --- End Configuration --- + +echo "Starting pod termination monitor..." +echo "Stuck Threshold: ${STUCK_DURATION_SECONDS}s | Check Interval: ${CHECK_INTERVAL_SECONDS}s" + +while true; do + echo "$(date '+%Y-%m-%d %H:%M:%S') - Checking for stuck pods..." + + # Get all pods with a deletion timestamp in JSON format. + # The 'deletionTimestamp' field is only present for pods that are being terminated. + stuck_pods_json=$(kubectl get pods -o json | jq -c '.items[] | select(.metadata.deletionTimestamp)') + + if [ -z "$stuck_pods_json" ]; then + echo "No pods are currently in a terminating state." + else + echo "$stuck_pods_json" | while read -r pod_json; do + # Extract details from the pod's JSON data + pod_name=$(echo "$pod_json" | jq -r '.metadata.name') + pod_namespace=$(echo "$pod_json" | jq -r '.metadata.namespace') + deletion_timestamp_str=$(echo "$pod_json" | jq -r '.metadata.deletionTimestamp') + + # Convert the RFC3339 timestamp to a Unix epoch timestamp + # Works on both GNU and BSD (macOS) date commands. 
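+            # GNU date parses RFC3339 via -d; BSD date needs -jf plus an explicit format.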
+ if date --version >/dev/null 2>&1; then # GNU date + deletion_ts=$(date -d "$deletion_timestamp_str" +%s) + else # BSD date + deletion_ts=$(date -jf "%Y-%m-%dT%H:%M:%SZ" "$deletion_timestamp_str" +%s) + fi + + # Get the current time as a Unix epoch timestamp + now_ts=$(date +%s) + + # Calculate how long the pod has been terminating + duration=$((now_ts - deletion_ts)) + + echo " - Checking pod '$pod_name' in namespace '$pod_namespace' (terminating for ${duration}s)" + + if [ "$duration" -gt "$STUCK_DURATION_SECONDS" ]; then + echo " -> STUCK! Pod '$pod_name' has been terminating for ${duration}s. Forcing deletion." + # Force delete the pod. The --grace-period=0 is crucial. + kubectl delete pod "$pod_name" -n "$pod_namespace" --force --grace-period=0 + fi + done + fi + + echo "Check complete. Sleeping for ${CHECK_INTERVAL_SECONDS} seconds..." + sleep "$CHECK_INTERVAL_SECONDS" +done From 0a079a1223aba1c7bf90aff4e2fcb6178761fc81 Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Sun, 3 Aug 2025 15:47:00 -0700 Subject: [PATCH 34/57] use BlockingRecreate --- axlearn/cloud/gcp/job.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/axlearn/cloud/gcp/job.py b/axlearn/cloud/gcp/job.py index e7b8e56f5..5099aa3c3 100644 --- a/axlearn/cloud/gcp/job.py +++ b/axlearn/cloud/gcp/job.py @@ -135,9 +135,9 @@ def _build_jobset(self) -> Nested[Any]: metadata=dict(name=cfg.name, annotations=annotations), spec=dict( failurePolicy=dict( - # maxRestarts=cfg.max_tries - 1, restartStrategy="BlockingRecreate" - maxRestarts=cfg.max_tries - - 1, + maxRestarts=cfg.max_tries - 1, + restartStrategy="BlockingRecreate" + # maxRestarts=cfg.max_tries - 1, ), replicatedJobs=self._builder(), ), From 227c94bcdba07eaf8808b3e4f42923c43411b560 Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Sun, 3 Aug 2025 22:49:59 -0700 Subject: [PATCH 35/57] add goodput recorder --- test-orbax.sh | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/test-orbax.sh b/test-orbax.sh index e6b7867b8..be3b4b5ec 100755 --- a/test-orbax.sh +++ b/test-orbax.sh @@ -57,6 +57,11 @@ if [[ "$CONFIG" == *"orbaxem"* ]]; then --jax_backend=tpu \ --mesh_selector=${MESH_SELECTOR} \ --initialization_timeout=1200 \ + --recorder_type=axlearn.cloud.gcp.measurement:goodput \ + --recorder_spec=name=goodput_${JOBSET_NAME} \ + --recorder_spec=upload_dir=${TRAINER_DIR}/summaries \ + --recorder_spec=upload_interval=30 \ + --recorder_spec=step_deviation_interval_seconds=30 \ --trace_at_steps=29,59,89,119,149,179,209,239,269,299,329,359,389,419,449,479,509,539,569,599,629,659,689,719 else From 0c2ce00e3598a2a7b7dea69e806f201b4af4f123 Mon Sep 17 00:00:00 2001 From: Dipannita Shaw Date: Fri, 25 Jul 2025 16:47:22 +0000 Subject: [PATCH 36/57] Integrate AXLearn with latest Goodput package --- axlearn/cloud/gcp/measurement.py | 202 ++++++---- axlearn/cloud/gcp/measurement_test.py | 508 +++++++++++++++++--------- axlearn/common/launch_trainer.py | 18 +- axlearn/common/launch_trainer_main.py | 1 - axlearn/common/measurement.py | 55 ++- axlearn/common/measurement_test.py | 52 ++- axlearn/common/trainer.py | 340 +++++++++-------- docs/05-Goodput-Monitoring.md | 108 ++++-- pyproject.toml | 2 +- 9 files changed, 784 insertions(+), 502 deletions(-) diff --git a/axlearn/cloud/gcp/measurement.py b/axlearn/cloud/gcp/measurement.py index 0d4ce0069..0eb226e6f 100644 --- a/axlearn/cloud/gcp/measurement.py +++ b/axlearn/cloud/gcp/measurement.py @@ -2,6 +2,9 @@ """Measurement utils for GCP. 
+ For detailed documentation and advanced usage, please refer to: + axlearn/docs/05-Goodput-Monitoring.md + Example: # Enable Goodput when launching an AXLearn training job @@ -13,10 +16,14 @@ --recorder_spec=name=my-run-with-goodput \ --recorder_spec=upload_dir=my-output-directory/summaries \ --recorder_spec=upload_interval=30 \ - --recorder_spec=step_deviation_interval_seconds=30 + --recorder_spec=rolling_window_size=86400,604800 """ +import contextlib +import os +from typing import Optional, Sequence + import jax from absl import flags, logging from ml_goodput_measurement import goodput @@ -38,13 +45,19 @@ class Config(measurement.Recorder.Config): Attributes: upload_dir: Directory to store metrics for the monitor. upload_interval: Time interval (seconds) for monitoring uploads. - step_deviation_interval_seconds: Time interval (seconds) for step deviation metrics - uploads. -1 to disable step deviation uploads. + See "How to Monitor Cumulative Goodput Metrics" in + docs/05-Goodput-Monitoring.md for more details. + rolling_window_size: A sequence of integers defining the rolling window sizes in + seconds. + See "How to Monitor Rolling Window Goodput Metrics" in + docs/05-Goodput-Monitoring.md for more details. + jax_backend: Jax backend type to infer Pathways environment. """ upload_dir: Required[str] = REQUIRED upload_interval: Required[int] = REQUIRED - step_deviation_interval_seconds: int = 30 # Default to 30 seconds + rolling_window_size: Sequence[int] = [] + jax_backend: Optional[str] = None @classmethod def from_flags(cls, fv: flags.FlagValues) -> "GoodputRecorder": @@ -53,68 +66,78 @@ def from_flags(cls, fv: flags.FlagValues) -> "GoodputRecorder": `fv.recorder_spec` will be interpreted as a list of `key=value` pairs; config names corresponding to keys will be set to the corresponding values. A GoodputRecorder can additionally take in following Tensorboard configs in the recorder_spec: - - upload_dir: The directory to write Tensorboard data to. - - upload_interval: The time interval in seconds at which to query and upload data - to Tensorboard. - - step_deviation_interval_seconds: Time interval (seconds) for step deviation metrics - uploads. Set to less than or equal to 0 to disable step deviation uploads. + - upload_dir: The directory to write Tensorboard data to. + - upload_interval: The time interval in seconds at which to query and upload data + to Tensorboard. + - rolling_window_size: Comma-separated list of integers representing rolling window + sizes in seconds. + - jax_backend: The type of jax backend. """ cfg: measurement.Recorder.Config = cls.default_config() - cfg = maybe_set_config(cfg, **parse_kv_flags(fv.recorder_spec, delimiter="=")) - return cfg.instantiate() + parsed_flags = parse_kv_flags(fv.recorder_spec, delimiter="=") + if "upload_interval" in parsed_flags: + parsed_flags["upload_interval"] = int(parsed_flags["upload_interval"]) + if "rolling_window_size" in parsed_flags and isinstance( + parsed_flags["rolling_window_size"], str + ): + parsed_flags["rolling_window_size"] = [ + int(x) for x in parsed_flags["rolling_window_size"].split(",") + ] + return maybe_set_config(cfg, **parsed_flags).instantiate() def __init__(self, cfg): super().__init__(cfg) - cfg: GoodputRecorder.Config = self.config - self._recorder = None - self._monitor = None - - def record(self, event: measurement.Event, *args, **kwargs): - # Lazily instantiate the recorder. This avoids invoking jax before setup is complete. 
+ self._recorder: Optional[goodput.GoodputRecorder] = None + self._monitor: Optional[goodput_monitoring.GoodputMonitor] = None + self._rolling_window_monitor: Optional[goodput_monitoring.GoodputMonitor] = None + self._job_name = cfg.name + self._logger_name = f"goodput_logger_{cfg.name}" + + @contextlib.contextmanager + def record_event(self, event: measurement.Event, *args, **kwargs): + """Records a goodput event using a context manager.""" + # Lazily instantiate the recorder if it hasn't been already. if self._recorder is None: - cfg: GoodputRecorder.Config = self.config + if jax.process_index() == 0: + logging.info("Lazily instantiating goodput recorder.") self._recorder = goodput.GoodputRecorder( - job_name=cfg.name, - logger_name=f"goodput_logger_{cfg.name}", + job_name=self._job_name, + logger_name=self._logger_name, logging_enabled=(jax.process_index() == 0), ) - if event == measurement.Event.START_JOB: - self._recorder.record_job_start_time(*args, **kwargs) - elif event == measurement.Event.END_JOB: - self._recorder.record_job_end_time(*args, **kwargs) - elif event == measurement.Event.START_STEP: - self._recorder.record_step_start_time(*args, **kwargs) - elif event == measurement.Event.START_ACCELERATOR_INIT: - self._recorder.record_tpu_init_start_time(*args, **kwargs) - elif event == measurement.Event.END_ACCELERATOR_INIT: - self._recorder.record_tpu_init_end_time(*args, **kwargs) - elif event == measurement.Event.START_TRAINING_PREPARATION: - self._recorder.record_training_preparation_start_time(*args, **kwargs) - elif event == measurement.Event.END_TRAINING_PREPARATION: - self._recorder.record_training_preparation_end_time(*args, **kwargs) - elif event == measurement.Event.START_DATA_LOADING: - self._recorder.record_data_loading_start_time(*args, **kwargs) - elif event == measurement.Event.END_DATA_LOADING: - self._recorder.record_data_loading_end_time(*args, **kwargs) - elif event == measurement.Event.START_CUSTOM_BADPUT_EVENT: - self._recorder.record_custom_badput_event_start_time(*args, **kwargs) - elif event == measurement.Event.END_CUSTOM_BADPUT_EVENT: - self._recorder.record_custom_badput_event_end_time(*args, **kwargs) - else: - logging.log_first_n( - logging.WARNING, - "Ignoring unknown event %s", - 1, - event, + start_method_name = f"record_{event.value}_start_time" + end_method_name = f"record_{event.value}_end_time" + + record_event_start = getattr(self._recorder, start_method_name, None) + record_event_end = getattr(self._recorder, end_method_name, None) + + try: + if record_event_start: + record_event_start(*args, **kwargs) + except RuntimeError as e: + logging.warning( + "Failed to record start of event %s. Error: %s", event.value, e, exc_info=True ) - def start_monitoring(self, *args, **kwargs): - """Starts Monitoring of Goodput. + try: + yield + finally: + try: + if record_event_end: + record_event_end(*args, **kwargs) + except RuntimeError as e: + logging.warning( + "Failed to record end of event %s. Error: %s", event.value, e, exc_info=True + ) + + @contextlib.contextmanager + def _maybe_monitor_goodput(self, *args, **kwargs): + """Monitor cumulative goodput if enabled. Instantiate ml-goodput-measurement's GoodputMonitor to asynchronously calculate - Goodput and Badput at the upload_interval and upload to the specified TensorBoard - directory. + Goodput, Badput, Step & Disruption Information at the upload_interval to the + specified TensorBoard directory and Google Cloud Monitoring. 
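+        Only the JAX process with index 0 starts the monitor; all other processes
+        yield immediately without uploading metrics.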
        Note: This function requires initialization of distributed JAX before it is called.
        If there are internal GCP errors from querying and uploading data, these will be
        logged without affecting the workload. GoodputMonitor logs will provide further
        information if data is not being uploaded correctly.

        Default behavior is to push metrics to Google Cloud Monitoring. This behavior
        can be overridden by configuring `goodput_monitoring.GCPOptions`
        """
-        cfg: GoodputRecorder.Config = self.config
-        include_step_deviation = True
-        if jax.process_index() == 0:
+        if jax.process_index() != 0:
+            yield
+            return
+        try:
             if self._monitor is None:
-                if int(cfg.step_deviation_interval_seconds) <= 0:
-                    include_step_deviation = False
-
-                gcp_options = goodput_monitoring.GCPOptions(
-                    enable_gcp_goodput_metrics=True,
-                    enable_gcp_step_deviation_metrics=include_step_deviation,
-                )
                 self._monitor = goodput_monitoring.GoodputMonitor(
-                    job_name=cfg.name,
-                    logger_name=f"goodput_logger_{cfg.name}",
-                    tensorboard_dir=cfg.upload_dir,
-                    upload_interval=int(cfg.upload_interval),
+                    job_name=self._job_name,
+                    logger_name=self._logger_name,
+                    tensorboard_dir=self.config.upload_dir,
+                    upload_interval=self.config.upload_interval,
                     monitoring_enabled=True,
+                    pathway_enabled=self.config.jax_backend == "proxy",
                     include_badput_breakdown=True,
-                    include_step_deviation=include_step_deviation,
-                    step_deviation_interval_seconds=int(cfg.step_deviation_interval_seconds),
-                    gcp_options=gcp_options,
                 )
-                self._monitor.start_goodput_uploader(*args, **kwargs)
-                logging.info("Started Goodput upload to Tensorboard & GCM in the background!")
-                if include_step_deviation:
-                    self._monitor.start_step_deviation_uploader(*args, **kwargs)
+            self._monitor.start_goodput_uploader(*args, **kwargs)
+            logging.info("Started Goodput upload to Tensorboard & GCM in the background!")
+            yield
+        finally:
+            if self._monitor:
+                self._monitor.stop_goodput_uploader()
+                logging.info("Flushed final metrics and safely exited from Goodput monitoring.")
+
+    @contextlib.contextmanager
+    def _maybe_monitor_rolling_window_goodput(self):
+        """Monitor rolling window goodput if enabled."""
+        if not self.config.rolling_window_size or jax.process_index() != 0:
+            yield
+            return
+        try:
+            if self._rolling_window_monitor is None:
+                rolling_window_tensorboard_dir = os.path.join(
+                    self.config.upload_dir, f"rolling_window_{self.config.name}"
+                )
+                self._rolling_window_monitor = goodput_monitoring.GoodputMonitor(
+                    job_name=self._job_name,
+                    logger_name=self._logger_name,
+                    tensorboard_dir=rolling_window_tensorboard_dir,
+                    upload_interval=self.config.upload_interval,
+                    monitoring_enabled=True,
+                    pathway_enabled=self.config.jax_backend == "proxy",
+                    include_badput_breakdown=True,
+                )
+            self._rolling_window_monitor.start_rolling_window_goodput_uploader(
+                self.config.rolling_window_size
+            )
+            logging.info("Started Rolling Window Goodput monitoring in the background!")
+            yield
+        finally:
+            if self._rolling_window_monitor:
+                self._rolling_window_monitor.stop_rolling_window_goodput_uploader()
+                logging.info(
+                    "Flushed final metrics and safely exited from Rolling Window Goodput monitoring."
) + + def maybe_monitor_all_goodput(self): + goodput_monitor_manager = self._maybe_monitor_goodput() + rolling_goodput_monitor_manager = self._maybe_monitor_rolling_window_goodput() + + @contextlib.contextmanager + def monitor_goodput(): + with goodput_monitor_manager, rolling_goodput_monitor_manager: + yield + + return monitor_goodput() diff --git a/axlearn/cloud/gcp/measurement_test.py b/axlearn/cloud/gcp/measurement_test.py index e14fc16c4..e944a262c 100644 --- a/axlearn/cloud/gcp/measurement_test.py +++ b/axlearn/cloud/gcp/measurement_test.py @@ -3,191 +3,373 @@ """Tests measurement utils for GCP.""" # pylint: disable=protected-access -import contextlib from unittest import mock -from absl import flags +from absl import flags, logging from absl.testing import parameterized from axlearn.cloud.gcp.measurement import GoodputRecorder from axlearn.common import measurement +from axlearn.common.config import RequiredFieldMissingError class GoodputRecorderTest(parameterized.TestCase): """Tests GoodputRecorder.""" @parameterized.parameters( - (None,), (["name=test-name", "upload_dir=/test/path/to/upload", "upload_interval=15"],) - ) - def test_from_flags(self, spec): - fv = flags.FlagValues() - measurement.define_flags(flag_values=fv) - if spec is not None: - fv.set_default("recorder_spec", spec) - fv.mark_as_parsed() - - if spec is None: - ctx = self.assertRaisesRegex(ValueError, "name") - else: - ctx = contextlib.nullcontext() - - with ctx: - recorder = GoodputRecorder.from_flags(fv) - # Recorder is not instantiated until first event. - self.assertIsNone(recorder._recorder) - - def test_record_and_monitor(self): - fv = flags.FlagValues() - measurement.define_flags(flag_values=fv) - fv.set_default( - "recorder_spec", - ["name=test-name", "upload_dir=/test/path/to/upload", "upload_interval=15"], - ) - fv.mark_as_parsed() - - recorder = GoodputRecorder.from_flags(fv) - recorder._recorder = mock.MagicMock() - recorder.record(measurement.Event.START_JOB) - self.assertTrue(recorder._recorder.record_job_start_time.called) - - def test_start_goodput_monitoring(self): - fv = flags.FlagValues() - measurement.define_flags(flag_values=fv) - fv.set_default( - "recorder_spec", - [ + dict( + recorder_spec=[ "name=test-name", - "upload_dir=/test/path/to/upload", + "upload_dir=/test/path", "upload_interval=15", - "step_deviation_interval_seconds=-1", ], - ) - fv.mark_as_parsed() - - recorder = GoodputRecorder.from_flags(fv) - self.assertIsNone(recorder._monitor) # Ensure _monitor is initially None - - with mock.patch("ml_goodput_measurement.monitoring.GoodputMonitor") as mock_goodput_monitor: - with mock.patch("ml_goodput_measurement.monitoring.GCPOptions") as mock_gcp_options: - mock_monitor_instance = mock_goodput_monitor.return_value - recorder.start_monitoring() - mock_gcp_options.assert_called_once_with( - enable_gcp_goodput_metrics=True, - enable_gcp_step_deviation_metrics=False, - ) - mock_gcp_options_instance = mock_gcp_options.return_value - - # Check that GoodputMonitor was instantiated - mock_goodput_monitor.assert_called_once_with( - job_name="test-name", - logger_name="goodput_logger_test-name", - tensorboard_dir="/test/path/to/upload", - upload_interval=15, - monitoring_enabled=True, - include_badput_breakdown=True, - include_step_deviation=False, - step_deviation_interval_seconds=-1, - gcp_options=mock_gcp_options_instance, - ) - - # Ensure that start_goodput_uploader is called on the monitor instance - mock_monitor_instance.start_goodput_uploader.assert_called_once() - 
self.assertIsNotNone(recorder._monitor) - - def test_start_goodput_and_step_deviation_monitoring(self): - fv = flags.FlagValues() - measurement.define_flags(flag_values=fv) - fv.set_default( - "recorder_spec", - [ + expected_rolling_window_size=[], + expected_jax_backend=None, + ), + dict( + recorder_spec=[ "name=test-name", - "upload_dir=/test/path/to/upload", + "upload_dir=/test/path", "upload_interval=15", - "step_deviation_interval_seconds=30", + "rolling_window_size=1,2,3", + "jax_backend=proxy", ], + expected_rolling_window_size=[1, 2, 3], + expected_jax_backend="proxy", + ), + ) + def test_from_flags( + self, + recorder_spec, + expected_rolling_window_size, + expected_jax_backend, + ): + """Tests that flags are correctly parsed into the config.""" + mock_fv = mock.MagicMock(spec=flags.FlagValues) + mock_fv.recorder_spec = recorder_spec + mock_fv.jax_backend = "tpu" + + recorder = GoodputRecorder.from_flags(mock_fv) + + self.assertEqual("test-name", recorder.config.name) + self.assertEqual("/test/path", recorder.config.upload_dir) + self.assertEqual(15, recorder.config.upload_interval) + self.assertEqual(expected_rolling_window_size, recorder.config.rolling_window_size) + self.assertEqual(expected_jax_backend, recorder.config.jax_backend) + + def test_from_flags_missing_required(self): + """Tests that missing required flags raise an error.""" + mock_fv = mock.MagicMock(spec=flags.FlagValues) + mock_fv.recorder_spec = ["name=test-name"] # Missing upload_dir/interval + mock_fv.jax_backend = "tpu" + with self.assertRaisesRegex(RequiredFieldMissingError, "upload_dir"): + GoodputRecorder.from_flags(mock_fv) + + @parameterized.parameters( + dict( + event=measurement.Event.JOB, + expected_start="record_job_start_time", + expected_end="record_job_end_time", + args=(), + kwargs={}, + expect_end_call=True, + ), + dict( + event=measurement.Event.STEP, + expected_start="record_step_start_time", + expected_end=None, + args=(123,), + kwargs={}, + expect_end_call=False, + ), + dict( + event=measurement.Event.ACCELERATOR_INIT, + expected_start="record_tpu_init_start_time", + expected_end="record_tpu_init_end_time", + args=(), + kwargs={}, + expect_end_call=True, + ), + dict( + event=measurement.Event.TRAINING_PREPARATION, + expected_start="record_training_preparation_start_time", + expected_end="record_training_preparation_end_time", + args=(), + kwargs={}, + expect_end_call=True, + ), + dict( + event=measurement.Event.DATA_LOADING, + expected_start="record_data_loading_start_time", + expected_end="record_data_loading_end_time", + args=(), + kwargs={}, + expect_end_call=True, + ), + dict( + event=measurement.Event.CUSTOM_BADPUT_EVENT, + expected_start="record_custom_badput_event_start_time", + expected_end="record_custom_badput_event_end_time", + args=(), + kwargs={"custom_badput_event_type": "TEST_TYPE"}, + expect_end_call=True, + ), + ) + @mock.patch("jax.process_index", return_value=0) + def test_record_event_context_manager_success( + self, _, event, expected_start, expected_end, args, kwargs, expect_end_call + ): + """Tests that record_event calls correct start and end methods with args and kwargs.""" + cfg = GoodputRecorder.default_config().set( + name="test", + upload_dir="/tmp/test", + upload_interval=1, ) - fv.mark_as_parsed() - - recorder = GoodputRecorder.from_flags(fv) - self.assertIsNone(recorder._monitor) # Ensure _monitor is initially None - - with mock.patch("ml_goodput_measurement.monitoring.GoodputMonitor") as mock_goodput_monitor: - with 
mock.patch("ml_goodput_measurement.monitoring.GCPOptions") as mock_gcp_options: - mock_monitor_instance = mock_goodput_monitor.return_value - recorder.start_monitoring() - mock_gcp_options.assert_called_once_with( - enable_gcp_goodput_metrics=True, - enable_gcp_step_deviation_metrics=True, - ) - mock_gcp_options_instance = mock_gcp_options.return_value - - # Check that GoodputMonitor was instantiated - mock_goodput_monitor.assert_called_once_with( - job_name="test-name", - logger_name="goodput_logger_test-name", - tensorboard_dir="/test/path/to/upload", - upload_interval=15, - monitoring_enabled=True, - include_badput_breakdown=True, - include_step_deviation=True, - step_deviation_interval_seconds=30, - gcp_options=mock_gcp_options_instance, - ) + recorder = GoodputRecorder(cfg) - # Ensure that start_goodput_uploader and start_step_deviation_uploader is called on - # the monitor instance - mock_monitor_instance.start_goodput_uploader.assert_called_once() - mock_monitor_instance.start_step_deviation_uploader.assert_called_once() - self.assertIsNotNone(recorder._monitor) - - def test_missing_required_flags(self): - fv = flags.FlagValues() - measurement.define_flags(flag_values=fv) - # Missing 'upload_dir' and 'upload_interval' from recorder_spec - fv.set_default("recorder_spec", ["name=test-name"]) # Incomplete config - fv.mark_as_parsed() - - # Expecting ValueError since 'upload_dir' and 'upload_interval' are required - with self.assertRaises(ValueError): - GoodputRecorder.from_flags(fv) - - def test_monitoring_initialization_failure(self): - fv = flags.FlagValues() - measurement.define_flags(flag_values=fv) - fv.set_default( - "recorder_spec", - ["name=test-name", "upload_dir=/test/path/to/upload", "upload_interval=15"], + with mock.patch("ml_goodput_measurement.goodput.GoodputRecorder") as mock_recorder_cls: + mock_instance = mock_recorder_cls.return_value + + start_mock = mock.MagicMock() + setattr(mock_instance, expected_start, start_mock) + if expect_end_call and expected_end: + end_mock = mock.MagicMock() + setattr(mock_instance, expected_end, end_mock) + + with recorder.record_event(event, *args, **kwargs): + pass + + mock_recorder_cls.assert_called_once() + start_mock.assert_called_once_with(*args, **kwargs) + if expect_end_call and expected_end: + end_mock.assert_called_once_with(*args, **kwargs) + + def test_record_event_context_manager_handles_runtime_error(self): + cfg = GoodputRecorder.default_config().set( + name="test", + upload_dir="/tmp/test", + upload_interval=1, + ) + recorder = GoodputRecorder(cfg) + + with mock.patch("jax.process_index", return_value=0): + with mock.patch( + "ml_goodput_measurement.goodput.GoodputRecorder" + ) as mock_recorder_cls, mock.patch.object(logging, "warning") as mock_warning: + mock_instance = mock_recorder_cls.return_value + + def raise_runtime_error(*args, **kwargs): + raise RuntimeError("mocked error") + + mock_instance.record_job_start_time.side_effect = raise_runtime_error + mock_instance.record_job_end_time.side_effect = raise_runtime_error + # Should not crash here. 
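+                # Exercise both warning paths: the start and end hooks each raise.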
+ with recorder.record_event(measurement.Event.JOB): + pass + + # Assert warnings were logged for start and end failures + assert mock_warning.call_count == 2 + start_call = mock_warning.call_args_list[0] + end_call = mock_warning.call_args_list[1] + + assert "Failed to record" in start_call.args[0] + assert "Failed to record" in end_call.args[0] + + @parameterized.parameters( + dict(is_pathways_job=False, mock_jax_backend="tpu"), + dict(is_pathways_job=True, mock_jax_backend="proxy"), + dict(is_pathways_job=False, mock_jax_backend=None), + ) + @mock.patch("jax.process_index", return_value=0) + def test_maybe_monitor_goodput(self, _, is_pathways_job, mock_jax_backend): + """Tests the _maybe_monitor_goodput context manager.""" + cfg = GoodputRecorder.default_config().set( + name="test-monitor", + upload_dir="/test", + upload_interval=30, + jax_backend=mock_jax_backend, ) - fv.mark_as_parsed() - - recorder = GoodputRecorder.from_flags(fv) - self.assertIsNone(recorder._monitor) - - # Mock a failure in initializing the GoodputMonitor - with mock.patch( - "ml_goodput_measurement.monitoring.GoodputMonitor", - side_effect=Exception("Failed to initialize GoodputMonitor"), - ): - with self.assertRaises(Exception): - recorder.start_monitoring() - self.assertIsNone(recorder._monitor) - - def test_non_zero_process_index(self): - fv = flags.FlagValues() - measurement.define_flags(flag_values=fv) - fv.set_default( - "recorder_spec", - ["name=test-name", "upload_dir=/test/path/to/upload", "upload_interval=15"], + recorder = GoodputRecorder(cfg) + + with mock.patch("ml_goodput_measurement.monitoring.GoodputMonitor") as mock_monitor_cls: + mock_monitor_instance = mock_monitor_cls.return_value + with recorder._maybe_monitor_goodput(): + pass + + # Verify that GoodputMonitor was instantiated with the correct parameters. 
+ mock_monitor_cls.assert_called_once_with( + job_name="test-monitor", + logger_name="goodput_logger_test-monitor", + tensorboard_dir="/test", + upload_interval=30, + monitoring_enabled=True, + pathway_enabled=is_pathways_job, + include_badput_breakdown=True, + ) + mock_monitor_instance.start_goodput_uploader.assert_called_once() + mock_monitor_instance.stop_goodput_uploader.assert_called_once() + + @parameterized.parameters( + dict( + is_rolling_window_enabled=True, + rolling_window_size=[10, 20], + is_pathways_job=False, + mock_jax_backend="tpu", + ), + dict( + is_rolling_window_enabled=False, + rolling_window_size=[], + is_pathways_job=False, + mock_jax_backend="tpu", + ), + dict( + is_rolling_window_enabled=True, + rolling_window_size=[50], + is_pathways_job=True, + mock_jax_backend="proxy", + ), + ) + @mock.patch("jax.process_index", return_value=0) + def test_maybe_monitor_rolling_window( + self, + mock_process_index, + is_rolling_window_enabled, + rolling_window_size, + is_pathways_job, + mock_jax_backend, + ): # pylint: disable=unused-argument + """Tests the rolling window monitoring.""" + cfg = GoodputRecorder.default_config().set( + name="test-rolling", + upload_dir="/test", + upload_interval=30, + rolling_window_size=rolling_window_size, + jax_backend=mock_jax_backend, ) - fv.mark_as_parsed() + recorder = GoodputRecorder(cfg) - recorder = GoodputRecorder.from_flags(fv) - self.assertIsNone(recorder._monitor) + with mock.patch("ml_goodput_measurement.monitoring.GoodputMonitor") as mock_monitor_cls: + mock_monitor_instance = mock_monitor_cls.return_value + if not is_rolling_window_enabled: + with recorder._maybe_monitor_rolling_window_goodput(): + pass + mock_monitor_cls.assert_not_called() + return + with recorder._maybe_monitor_rolling_window_goodput(): + pass - with mock.patch("jax.process_index") as mock_process_index: - mock_process_index.return_value = 1 # Simulate a non-zero process index + mock_monitor_cls.assert_called_once_with( + job_name="test-rolling", + logger_name="goodput_logger_test-rolling", + tensorboard_dir="/test/rolling_window_test-rolling", + upload_interval=30, + monitoring_enabled=True, + pathway_enabled=is_pathways_job, + include_badput_breakdown=True, + ) + + mock_monitor_instance.start_rolling_window_goodput_uploader.assert_called_with( + rolling_window_size + ) + mock_monitor_instance.stop_rolling_window_goodput_uploader.assert_called_once() + + @mock.patch("jax.process_index", return_value=1) + def test_non_zero_process_index_skips_monitoring( + self, mock_process_index + ): # pylint: disable=unused-argument + """Tests that monitoring is skipped on non-zero process indices.""" + cfg = GoodputRecorder.default_config().set( + name="test", upload_dir="/test", upload_interval=30 + ) + recorder = GoodputRecorder(cfg) - try: - recorder.start_monitoring() - except AttributeError: - self.fail("AttributeError was raised unexpectedly.") + with mock.patch("ml_goodput_measurement.monitoring.GoodputMonitor") as mock_monitor_cls: + # Test cumulative goodput monitoring. 
+ with recorder._maybe_monitor_goodput(): + pass + mock_monitor_cls.assert_not_called() + + cfg_rolling = GoodputRecorder.default_config().set( + name="test-rolling-skip", + upload_dir="/test", + upload_interval=30, + rolling_window_size=[10, 20], + ) + recorder_rolling = GoodputRecorder(cfg_rolling) + with recorder_rolling._maybe_monitor_rolling_window_goodput(): + pass + mock_monitor_cls.assert_not_called() + + @parameterized.parameters( + dict( + rolling_window_size=[5, 10], + jax_backend="tpu", + expected_monitor_calls=2, # Cumulative & Rolling Window + expect_rolling=True, + expect_cumulative=True, + ), + dict( + rolling_window_size=[], + jax_backend="tpu", + expected_monitor_calls=1, # Cumulative only + expect_rolling=False, + expect_cumulative=True, + ), + dict( + rolling_window_size=[5, 10], + jax_backend=None, # Disables Pathways + expected_monitor_calls=2, + expect_rolling=True, + expect_cumulative=True, + ), + dict( + rolling_window_size=[], + jax_backend=None, + expected_monitor_calls=1, + expect_rolling=False, + expect_cumulative=True, + ), + ) + @mock.patch("jax.process_index", return_value=0) + def test_maybe_monitor_all_goodput( + self, + _, + rolling_window_size, + jax_backend, + expected_monitor_calls, + expect_rolling, + expect_cumulative, + ): + """Tests all goodput monitoring with various configs.""" + cfg = GoodputRecorder.default_config().set( + name="test-all", + upload_dir="/test", + upload_interval=30, + rolling_window_size=rolling_window_size, + jax_backend=jax_backend, + ) + recorder = GoodputRecorder(cfg) + + with mock.patch("ml_goodput_measurement.monitoring.GoodputMonitor") as mock_monitor_cls: + mock_monitor_instance = mock_monitor_cls.return_value + + with recorder.maybe_monitor_all_goodput(): + pass + + self.assertEqual(mock_monitor_cls.call_count, expected_monitor_calls) + + if expect_cumulative: + mock_monitor_instance.start_goodput_uploader.assert_called_once() + mock_monitor_instance.stop_goodput_uploader.assert_called_once() + else: + mock_monitor_instance.start_goodput_uploader.assert_not_called() + mock_monitor_instance.stop_goodput_uploader.assert_not_called() + + if expect_rolling: + mock_monitor_instance.start_rolling_window_goodput_uploader.assert_called_once_with( + rolling_window_size + ) + mock_monitor_instance.stop_rolling_window_goodput_uploader.assert_called_once() + else: + mock_monitor_instance.start_rolling_window_goodput_uploader.assert_not_called() + mock_monitor_instance.stop_rolling_window_goodput_uploader.assert_not_called() diff --git a/axlearn/common/launch_trainer.py b/axlearn/common/launch_trainer.py index bba28533e..7470ad66c 100644 --- a/axlearn/common/launch_trainer.py +++ b/axlearn/common/launch_trainer.py @@ -2,6 +2,7 @@ """Utilities to launch a trainer.""" +import contextlib import json import os from typing import Any, Optional @@ -128,8 +129,8 @@ def get_trainer_config( return trainer_config -def run_trainer(trainer_config: SpmdTrainer.Config) -> Any: - measurement.record_event(measurement.Event.START_JOB) +def _run_trainer_impl(trainer_config: SpmdTrainer.Config) -> Any: + """Instantiates and runs the trainer.""" trainer_config_debug_string = trainer_config.debug_string() logging.info("Trainer config:\n%s", trainer_config_debug_string) if jax.process_index() == 0: @@ -149,6 +150,13 @@ def run_trainer(trainer_config: SpmdTrainer.Config) -> Any: trainer: SpmdTrainer = trainer_config.instantiate(parent=None) prng_key = jax.random.PRNGKey(seed=FLAGS.trainer_prng_seed) - output = trainer.run(prng_key) - 
measurement.record_event(measurement.Event.END_JOB) - return output + return trainer.run(prng_key) + + +def run_trainer(trainer_config: SpmdTrainer.Config) -> Any: + recorder = measurement.global_recorder + job_events_manager = ( + recorder.record_event(measurement.Event.JOB) if recorder else contextlib.nullcontext() + ) + with job_events_manager: + return _run_trainer_impl(trainer_config) diff --git a/axlearn/common/launch_trainer_main.py b/axlearn/common/launch_trainer_main.py index 2f617b4cd..8d170a950 100644 --- a/axlearn/common/launch_trainer_main.py +++ b/axlearn/common/launch_trainer_main.py @@ -13,7 +13,6 @@ def main(_): launch.setup() trainer_config = launch_trainer.get_trainer_config() trainer_config.set(recorder=config_for_function(lambda: measurement.global_recorder)) - measurement.start_monitoring() launch_trainer.run_trainer(trainer_config) diff --git a/axlearn/common/measurement.py b/axlearn/common/measurement.py index b0a40a85f..1d2a9dea7 100644 --- a/axlearn/common/measurement.py +++ b/axlearn/common/measurement.py @@ -2,6 +2,7 @@ """A library to measure e2e metrics like goodput.""" +import contextlib import enum import importlib from typing import Optional, TypeVar @@ -15,30 +16,20 @@ class Event(enum.Enum): """Event to be recorded. Attributes: - START_JOB: Start of job. - END_JOB: End of job. - START_STEP: Start of a training step. Should be recorded with `step` as a positional arg. - START_ACCELERATOR_INIT: Start of accelerator mesh initialization. - END_ACCELERATOR_INIT: End of accelerator mesh initialization. - START_TRAINING_PREPARATION: Start of training preparation. - END_TRAINING_PREPARATION: End of training preparation. - START_DATA_LOADING: Start of data loading. - END_DATA_LOADING: End of data loading. - START_CUSTOM_BADPUT_EVENT: Start of custom badput event. - END_CUSTOM_BADPUT_EVENT: End of custom badput event. + JOB: Start and end of the job. + STEP: Start of a training step. Should be recorded with `step` as a positional arg. + ACCELERATOR_INIT: Start and end of accelerator mesh initialization. + TRAINING_PREPARATION: Start and end of training preparation. + DATA_LOADING: Start and end of data loading. + CUSTOM_BADPUT_EVENT: Start and end of custom badput events. 
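+
+    Note: each value doubles as a method-name stem; e.g., a recorder can resolve
+    ACCELERATOR_INIT ("tpu_init") to record_tpu_init_start_time and
+    record_tpu_init_end_time.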
""" - START_JOB = "START_JOB" - END_JOB = "END_JOB" - START_STEP = "START_STEP" - START_ACCELERATOR_INIT = "START_ACCELERATOR_INIT" - END_ACCELERATOR_INIT = "END_ACCELERATOR_INIT" - START_TRAINING_PREPARATION = "START_TRAINING_PREPARATION" - END_TRAINING_PREPARATION = "END_TRAINING_PREPARATION" - START_DATA_LOADING = "START_DATA_LOADING" - END_DATA_LOADING = "END_DATA_LOADING" - START_CUSTOM_BADPUT_EVENT = "START_CUSTOM_BADPUT_EVENT" - END_CUSTOM_BADPUT_EVENT = "END_CUSTOM_BADPUT_EVENT" + JOB = "job" + STEP = "step" + ACCELERATOR_INIT = "tpu_init" + TRAINING_PREPARATION = "training_preparation" + DATA_LOADING = "data_loading" + CUSTOM_BADPUT_EVENT = "custom_badput_event" class Recorder(Configurable): @@ -59,9 +50,15 @@ def from_flags(cls, fv: Optional[flags.FlagValues]) -> "Recorder": """Converts flags to a recorder.""" raise NotImplementedError(cls) - def record(self, event: Event, *args, **kwargs): - """Records an event with the given name.""" - raise NotImplementedError(type(self)) + @contextlib.contextmanager + def record_event(self, event: Event, *args, **kwargs): + """A context manager to record the start and end of an event.""" + # pylint: disable=unnecessary-pass + # pylint: disable=unused-argument + try: + yield + finally: + pass def start_monitoring(self, **kwargs): """Starts computing and uploading metrics at some configured interval in the background.""" @@ -134,14 +131,6 @@ def initialize(fv: flags.FlagValues): ) -def record_event(event: Event): - """Records a global event.""" - if global_recorder is None: - logging.log_first_n(logging.INFO, "No recorder configured, ignoring events.", 1) - else: - global_recorder.record(event) - - def start_monitoring(): """Begins monitoring events as per global monitor functionality.""" if global_recorder is None: diff --git a/axlearn/common/measurement_test.py b/axlearn/common/measurement_test.py index c9043f20b..d36605f29 100644 --- a/axlearn/common/measurement_test.py +++ b/axlearn/common/measurement_test.py @@ -3,24 +3,30 @@ """Tests measurement utils.""" # pylint: disable=protected-access +import contextlib from unittest import mock from absl import flags from absl.testing import parameterized from axlearn.common import measurement +from axlearn.experiments.testdata.axlearn_common_measurement_test.dummy_recorder import ( + DummyRecorder as RealDummyRecorder, +) class UtilsTest(parameterized.TestCase): """Tests utils.""" def setUp(self): + super().setUp() self._orig_recorder = measurement.global_recorder - self._orig_recorders = measurement._recorders + self._orig_recorders = measurement._recorders.copy() measurement.global_recorder = None measurement._recorders = {} def tearDown(self): + super().tearDown() measurement.global_recorder = self._orig_recorder measurement._recorders = self._orig_recorders @@ -33,32 +39,25 @@ class DummyRecorder(measurement.Recorder): self.assertEqual(DummyRecorder, measurement._recorders.get("test")) - # Registering twice should fail. with self.assertRaisesRegex(ValueError, "already registered"): measurement.register_recorder("test")(DummyRecorder) @parameterized.parameters( - # No-op if no recorder_type provided. - dict( - recorder_type=None, - expected=None, - ), - dict( - recorder_type="test", - expected="Mock", - ), - # Try initializing from another module. 
+ dict(recorder_type=None), + dict(recorder_type="test"), dict( recorder_type=( - f"axlearn.experiments.testdata.{__name__.replace('.', '_')}.dummy_recorder:" + "axlearn.experiments.testdata.axlearn_common_measurement_test.dummy_recorder:" "dummy_recorder" - ), - expected="DummyRecorder", + ) ), ) - def test_initialize(self, recorder_type, expected): - mock_recorder = mock.MagicMock() - measurement.register_recorder("test")(mock_recorder) + def test_initialize(self, recorder_type): + mock_recorder_cls = mock.MagicMock() + mock_recorder_instance = mock_recorder_cls.from_flags.return_value + mock_recorder_instance.record_event.return_value = contextlib.nullcontext() + measurement.register_recorder("test")(mock_recorder_cls) + measurement.register_recorder("dummy_recorder")(RealDummyRecorder) fv = flags.FlagValues() measurement.define_flags(flag_values=fv) @@ -69,24 +68,17 @@ def test_initialize(self, recorder_type, expected): measurement.initialize(fv) if recorder_type is None: - # global_recorder should not be initialized, and record_event should be no-op. self.assertIsNone(measurement.global_recorder) - measurement.record_event(measurement.Event.START_JOB) return recorder_name = recorder_type.split(":", 1)[-1] if recorder_name == "test": - self.assertTrue(mock_recorder.from_flags.called) - - self.assertIn(expected, str(measurement._recorders.get(recorder_name, None))) - self.assertIn(expected, str(measurement.global_recorder)) - - # Ensure that record_event does not fail. - with mock.patch.object(measurement.global_recorder, "record") as mock_record: - measurement.record_event(measurement.Event.START_JOB) - self.assertIn(measurement.Event.START_JOB, mock_record.call_args[0]) + self.assertEqual(mock_recorder_instance, measurement.global_recorder) + mock_recorder_cls.from_flags.assert_called_once() + elif recorder_name == "dummy_recorder": + self.assertIsNotNone(measurement.global_recorder) + self.assertIsInstance(measurement.global_recorder, RealDummyRecorder) - # Ensure that start_monitoring does not fail. with mock.patch.object( measurement.global_recorder, "start_monitoring" ) as mock_start_monitoring: diff --git a/axlearn/common/trainer.py b/axlearn/common/trainer.py index 1c6c710aa..9cbb37262 100644 --- a/axlearn/common/trainer.py +++ b/axlearn/common/trainer.py @@ -241,116 +241,121 @@ def __init__( self._device_monitor = maybe_instantiate(cfg.device_monitor) self._recorder = maybe_instantiate(cfg.recorder) self._is_initialized: bool = False - self._maybe_record_event(measurement.Event.START_ACCELERATOR_INIT) + # Accelerator initialization. + with self._record_event(measurement.Event.ACCELERATOR_INIT): + if cfg.model.dtype is None: + raise ValueError(f"dtype must be explicitly specified for {self.path()}.model") + if cfg.model.param_init is None: + cfg.model.param_init = DefaultInitializer.default_config() + logging.info( + "model.param_init is not specified. Default to DefaultInitializer: %s", + cfg.model.param_init, + ) - if cfg.model.dtype is None: - raise ValueError(f"dtype must be explicitly specified for {self.path()}.model") - if cfg.model.param_init is None: - cfg.model.param_init = DefaultInitializer.default_config() - logging.info( - "model.param_init is not specified. Default to DefaultInitializer: %s", - cfg.model.param_init, + self._per_param_train_dtype = maybe_instantiate( + canonicalize_per_param_dtype(cfg.train_dtype) ) - self._per_param_train_dtype = maybe_instantiate( - canonicalize_per_param_dtype(cfg.train_dtype) - ) - - # Create the device mesh. 
- if devices is None: - self._step_log( - "Devices: global=%s local=%s %s", - jax.device_count(), - jax.local_device_count(), - [device.platform for device in jax.local_devices()], - ) - else: - local_devices = [d for d in devices.flatten() if d.process_index == jax.process_index()] - self._step_log( - "Devices: global=%s local=%s %s", - len(devices), - len(local_devices), - [device.platform for device in local_devices], - ) - self._step_log("Mesh shape: %s", cfg.mesh_shape) - devices = ( - utils.create_device_mesh(mesh_shape=cfg.mesh_shape) if devices is None else devices - ) - mesh = jax.sharding.Mesh(devices, cfg.mesh_axis_names) - self._step_log("Global mesh: %s", mesh) - self._mesh = mesh - self._context_manager: Callable[[], ContextManager] = ( - maybe_instantiate(cfg.context_manager) or contextlib.nullcontext - ) - xsc_check_policy = None - if cfg.xsc_check_policy: - if jax.default_backend() != "tpu": - # XSC is currently only supported on TPU XLA backend. - logging.warning( - "xsc_check_policy was set for non-TPU XLA backend. Running without XSC." + # Create the device mesh. + if devices is None: + self._step_log( + "Devices: global=%s local=%s %s", + jax.device_count(), + jax.local_device_count(), + [device.platform for device in jax.local_devices()], ) else: - xsc_check_policy = maybe_instantiate(cfg.xsc_check_policy) - self._xsc_check_policy: Optional[Callable[[int], bool]] = xsc_check_policy - self._compiled_train_step: Optional[jax.stages.Compiled] = None - - # Create all children within the mesh context so that utils.input_partition_spec() works - # properly. - with self.mesh(): - self.input: Input = self._add_child( - "input", - maybe_set_config( - cfg.input, partition_spec=PartitionSpec(cfg.batch_axis_names), is_training=True - ), - ) - # Start from the beginning of the input dataset by default. 
- self._input_iter = iter(self.input.dataset()) - cfg.summary_writer.dir = cfg.summary_writer.dir or os.path.join( - cfg.dir, "summaries", "train_train" - ) - self._add_child("summary_writer", cfg.summary_writer) - self._add_child("model", cfg.model) - self._add_child("learner", cfg.learner) - cfg.checkpointer.dir = cfg.checkpointer.dir or os.path.join(cfg.dir, "checkpoints") - self._add_child("checkpointer", cfg.checkpointer) - if cfg.init_state_builder is not None: - self._add_child("init_state_builder", cfg.init_state_builder) - - self._model_param_specs = self.model.create_parameter_specs_recursively() - model_param_partition_specs = jax.tree.map( - lambda spec: spec.mesh_axes, self._model_param_specs - ) - for name, spec in utils.flatten_items(self._model_param_specs): - self._step_log("Model param spec: %s=%s", name, spec) - self._learner_state_partition_specs = self.learner.create_state_partition_specs( - self._model_param_specs - ) - for name, spec in utils.flatten_items(self._learner_state_partition_specs): - self._step_log("Learner state spec: %s=%s", name, spec) - self._trainer_state_specs = TrainerState( - prng_key=ParameterSpec(dtype=jnp.uint32, shape=[4], mesh_axes=PartitionSpec(None)), - model=self._model_param_specs, - learner=self._learner_state_partition_specs, + local_devices = [ + d for d in devices.flatten() if d.process_index == jax.process_index() + ] + self._step_log( + "Devices: global=%s local=%s %s", + len(devices), + len(local_devices), + [device.platform for device in local_devices], + ) + self._step_log("Mesh shape: %s", cfg.mesh_shape) + devices = ( + utils.create_device_mesh(mesh_shape=cfg.mesh_shape) if devices is None else devices ) - self._trainer_state_partition_specs: TrainerState = jax.tree.map( - lambda spec: spec.sharding, self._trainer_state_specs + mesh = jax.sharding.Mesh(devices, cfg.mesh_axis_names) + self._step_log("Global mesh: %s", mesh) + self._mesh = mesh + self._context_manager: Callable[[], ContextManager] = ( + maybe_instantiate(cfg.context_manager) or contextlib.nullcontext ) - # Create evalers, which depend on model_param_partition_specs. - self._evalers = {} - for evaler_name, evaler_cfg in cfg.evalers.items(): - evaler_cfg.summary_writer.dir = evaler_cfg.summary_writer.dir or os.path.join( - cfg.dir, "summaries", evaler_name + xsc_check_policy = None + if cfg.xsc_check_policy: + if jax.default_backend() != "tpu": + # XSC is currently only supported on TPU XLA backend. + logging.warning( + "xsc_check_policy was set for non-TPU XLA backend. Running without XSC." + ) + else: + xsc_check_policy = maybe_instantiate(cfg.xsc_check_policy) + self._xsc_check_policy: Optional[Callable[[int], bool]] = xsc_check_policy + self._compiled_train_step: Optional[jax.stages.Compiled] = None + + # Create all children within the mesh context so that utils.input_partition_spec() works + # properly. + with self.mesh(): + if cfg.batch_axis_names is not None: + cfg.input = maybe_set_config( + cfg.input, partition_spec=PartitionSpec(cfg.batch_axis_names) + ) + self.input: Input = self._add_child( + "input", maybe_set_config(cfg.input, is_training=True) + ) + # Start from the beginning of the input dataset by default. 
+ self._input_iter = iter(self.input.dataset()) + cfg.summary_writer.dir = cfg.summary_writer.dir or os.path.join( + cfg.dir, "summaries", "train_train" ) - maybe_set_config( - evaler_cfg.input, partition_spec=PartitionSpec(cfg.batch_axis_names) + self._add_child("summary_writer", cfg.summary_writer) + self._add_child("model", cfg.model) + self._add_child("learner", cfg.learner) + cfg.checkpointer.dir = cfg.checkpointer.dir or os.path.join(cfg.dir, "checkpoints") + self._add_child("checkpointer", cfg.checkpointer) + if cfg.init_state_builder is not None: + self._add_child("init_state_builder", cfg.init_state_builder) + + self._model_param_specs = self.model.create_parameter_specs_recursively() + model_param_partition_specs = jax.tree.map( + lambda spec: spec.mesh_axes, self._model_param_specs ) - self._evalers[evaler_name] = self._add_child( - evaler_name, - evaler_cfg, - model=self.model, - model_param_partition_specs=model_param_partition_specs, + for name, spec in utils.flatten_items(self._model_param_specs): + self._step_log("Model param spec: %s=%s", name, spec) + self._learner_state_partition_specs = self.learner.create_state_partition_specs( + self._model_param_specs ) - self._maybe_record_event(measurement.Event.END_ACCELERATOR_INIT) + for name, spec in utils.flatten_items(self._learner_state_partition_specs): + self._step_log("Learner state spec: %s=%s", name, spec) + self._trainer_state_specs = TrainerState( + prng_key=ParameterSpec( + dtype=jnp.uint32, shape=[4], mesh_axes=PartitionSpec(None) + ), + model=self._model_param_specs, + learner=self._learner_state_partition_specs, + ) + self._trainer_state_partition_specs: TrainerState = jax.tree.map( + lambda spec: spec.sharding, self._trainer_state_specs + ) + # Create evalers, which depend on model_param_partition_specs. + self._evalers = {} + for evaler_name, evaler_cfg in cfg.evalers.items(): + evaler_cfg.summary_writer.dir = evaler_cfg.summary_writer.dir or os.path.join( + cfg.dir, "summaries", evaler_name + ) + if cfg.batch_axis_names is not None: + maybe_set_config( + evaler_cfg.input, partition_spec=PartitionSpec(cfg.batch_axis_names) + ) + self._evalers[evaler_name] = self._add_child( + evaler_name, + evaler_cfg, + model=self.model, + model_param_partition_specs=model_param_partition_specs, + ) @property def step(self): @@ -368,6 +373,15 @@ def trainer_state_specs(self): def trainer_state_partition_specs(self): return self._trainer_state_partition_specs + @contextlib.contextmanager + def _record_event(self, event: measurement.Event, *args, **kwargs): + """A helper to record an event if a recorder is configured.""" + if self._recorder: + with self._recorder.record_event(event, *args, **kwargs) as event_manager: + yield event_manager + else: + yield + def _train_step_input_partition_specs(self): # Note that subclasses may override this method to set a partition spec for pjit which is # different from that of the input partition spec. @@ -525,10 +539,6 @@ def _should_force_run_evals( ) return force_run_evals - def _maybe_record_event(self, event: measurement.Event, *args, **kwargs): - if self._recorder is not None: - self._recorder.record(event, *args, **kwargs) - # pylint: disable-next=too-many-statements,too-many-branches def run( self, prng_key: Tensor, *, return_evaler_summaries: Optional[Union[bool, set[str]]] = None @@ -554,6 +564,7 @@ def run( different types of values such as WeightedScalar, Tensor, or string, depending on the specific `metric_calculator` config of the evaler. 
""" + with ( ( self._device_monitor.start_monitoring() @@ -564,6 +575,7 @@ def run( self.mesh(), jax.log_compiles(self.vlog_is_on(1)), self._context_manager(), + self._recorder.maybe_monitor_all_goodput(), ): cfg = self.config # Check if need to force run evals at the last training step. @@ -572,8 +584,9 @@ def run( ) # Prepare training. - if not self._prepare_training(prng_key): - return None + with self._record_event(measurement.Event.TRAINING_PREPARATION): + if not self._prepare_training(prng_key): + return None self._is_initialized = True @@ -586,10 +599,10 @@ def run( input_iterator = self.input.batches(self._input_iter) while True: - self._maybe_record_event(measurement.Event.START_DATA_LOADING) try: - input_batch = next(input_iterator) - self._maybe_record_event(measurement.Event.END_DATA_LOADING) + with self._record_event(measurement.Event.DATA_LOADING): + input_batch = next(input_iterator) + logging.log_first_n( logging.INFO, "input_batch=%s", 3, utils.shapes(input_batch) ) @@ -599,18 +612,18 @@ def run( self._step = self._step + 1 self.vlog(3, "Start step %s", self.step) - self._maybe_record_event(measurement.Event.START_STEP, self._step) - output = self._run_step( - utils.host_to_global_device_array( - input_batch, - partition=self._train_step_input_partition_specs(), - ), - force_run_evals=( - force_run_eval_sets_at_max_step - if self.step >= cfg.max_step - else None - ), - ) + with self._record_event(measurement.Event.STEP, self._step): + output = self._run_step( + utils.host_to_global_array( + input_batch, + partition=self._train_step_input_partition_specs(), + ), + force_run_evals=( + force_run_eval_sets_at_max_step + if self.step >= cfg.max_step + else None + ), + ) self.vlog(3, "Done step %s", self.step) num_steps += 1 if num_steps % 1 == 0: @@ -624,9 +637,6 @@ def run( self._step_log("Reached max_step=%s. Stopping", cfg.max_step) break except StopIteration: - # Add END_DATA_LOADING event here to close the unpaired START_DATA_LOADING - # event. - self._maybe_record_event(measurement.Event.END_DATA_LOADING) break if self.step < cfg.max_step: self._step_log("Reached end of inputs. Stopping") @@ -867,7 +877,6 @@ def _prepare_training(self, prng_key: Tensor) -> bool: A boolean indicating whether the model training should start. If not, return None from the `run` function. """ - self._maybe_record_event(measurement.Event.START_TRAINING_PREPARATION) cfg = self.config # Attempt to restore the latest checkpoint, which may contain a saved `_input_iter`. 
@@ -900,7 +909,6 @@ def _prepare_training(self, prng_key: Tensor) -> bool: return False self._jit_train_step = self._pjit_train_step() - self._maybe_record_event(measurement.Event.END_TRAINING_PREPARATION) return True def restore_checkpoint(self, restore_step: Optional[int] = None) -> Optional[int]: @@ -1039,36 +1047,29 @@ def _get_compiled_train_step_fn( mesh_shape=cfg.mesh_shape, mesh_axis_names=cfg.mesh_axis_names, device_kind=device_kind ) if not with_xsc: - self._maybe_record_event( - measurement.Event.START_CUSTOM_BADPUT_EVENT, + with self._record_event( + measurement.Event.CUSTOM_BADPUT_EVENT, custom_badput_event_type="COMPILATION_NO_XSC", - ) - self._compiled_train_step = self.compile_train_step( - trainer_state=trainer_state, input_batch=input_batch, compiler_options=options - ) - self._maybe_record_event( - measurement.Event.END_CUSTOM_BADPUT_EVENT, - custom_badput_event_type="COMPILATION_NO_XSC", - ) + ): + self._compiled_train_step = self.compile_train_step( + trainer_state=trainer_state, input_batch=input_batch, compiler_options=options + ) return self._compiled_train_step + logging.log_first_n(logging.INFO, "Compiling XSC train step.", 1) - self._maybe_record_event( - measurement.Event.START_CUSTOM_BADPUT_EVENT, + with self._record_event( + measurement.Event.CUSTOM_BADPUT_EVENT, custom_badput_event_type="COMPILATION_WITH_XSC", - ) - compiled_jit_train_step_fn = self.compile_train_step( - trainer_state=trainer_state, - input_batch=input_batch, - compiler_options=options - | infer_xsc_compiler_options( - halt_on_detection=True, repeat_count=1, device_kind=device_kind - ), - ) - self._maybe_record_event( - measurement.Event.END_CUSTOM_BADPUT_EVENT, - custom_badput_event_type="COMPILATION_WITH_XSC", - ) + ): + compiled_jit_train_step_fn = self.compile_train_step( + trainer_state=trainer_state, + input_batch=input_batch, + compiler_options=options + | infer_xsc_compiler_options( + halt_on_detection=True, repeat_count=1, device_kind=device_kind + ), + ) return compiled_jit_train_step_fn def _run_step( @@ -1125,26 +1126,23 @@ def _run_eval( force_runs: Optional[set[str]] = None, ) -> dict[str, Any]: """Runs evaluations and returns the corresponding summaries.""" - self._maybe_record_event( - measurement.Event.START_CUSTOM_BADPUT_EVENT, custom_badput_event_type="EVAL" - ) - evaler_summaries = {} - # Note: we will use the same eval key as the training keys of the future step, - # which should be okay. - prng_key = self._trainer_state.prng_key - for evaler_name, evaler in self._evalers.items(): - prng_key, summaries, _ = evaler.eval_step( - self.step, - prng_key=prng_key, - model_params=self.model_params_for_eval(), - train_summaries=train_summaries, - force_run=bool(force_runs is not None and evaler_name in force_runs), - ) - evaler_summaries[evaler_name] = summaries - self._maybe_record_event( - measurement.Event.END_CUSTOM_BADPUT_EVENT, custom_badput_event_type="EVAL" - ) - return evaler_summaries + with self._record_event( + measurement.Event.CUSTOM_BADPUT_EVENT, custom_badput_event_type="EVAL" + ): + evaler_summaries = {} + # Note: we will use the same eval key as the training keys of the future step, + # which should be okay. 
+            prng_key = self._trainer_state.prng_key
+            for evaler_name, evaler in self._evalers.items():
+                prng_key, summaries, _ = evaler.eval_step(
+                    self.step,
+                    prng_key=prng_key,
+                    model_params=self.model_params_for_eval(),
+                    train_summaries=train_summaries,
+                    force_run=bool(force_runs is not None and evaler_name in force_runs),
+                )
+                evaler_summaries[evaler_name] = summaries
+            return evaler_summaries
 
     def _pjit_train_step(self) -> jax.stages.Wrapped:
         return pjit(
diff --git a/docs/05-Goodput-Monitoring.md b/docs/05-Goodput-Monitoring.md
index ca1452c19..cb17f6989 100644
--- a/docs/05-Goodput-Monitoring.md
+++ b/docs/05-Goodput-Monitoring.md
@@ -1,10 +1,14 @@
 # ML Goodput Monitoring
 
-AXLearn supports automatic measurement and upload of workload metrics such as
-Goodput, Badput Breakdown and Step Time Deviation using the ML Goodput
-Measurement library.
+AXLearn supports automatic measurement and upload of a wide range of workload
+metrics using the **ML Goodput Measurement** library. This includes:
+* **Goodput** and **Badput Breakdown**
+* **Step Metrics** (Ideal Step Time, Step Time Deviation, Last Productive Step etc.)
+* **Workload Hang Metrics** (Disruption Count, Step Info)
+* **Rolling Window Goodput & Badput Breakdown**
 
 The [ML Goodput Measurement](https://github.com/AI-Hypercomputer/ml-goodput-measurement) library currently supports monitoring workloads running on Google Cloud Platform. For more information on details of the library, visit the Github page or the [ml-goodput-measurement](https://pypi.org/project/ml-goodput-measurement/) PyPI package documentation.
 
+
 ### What is Goodput
 Goodput is the metric that measures the efficiency of model training jobs, i.e.
 productive time spent on training progress proportional to the total time spent
 by the workload. It is an actionable metric that tells you how efficiently your
 workloads are progressing and how much of their time is wasted, so you know
 exactly where to improve to get the most value from their accelerators.
 
 ### What is Badput
 Badput is the metric that measures time that a workload spent on anything that
 is not productive training proportional to the total time spent by the workload.
 For example, the time spent in accelerator initialization, training preparation,
-program startup, data loading, portions of checkpointing, disruptions and
-wasted progress since the last checkpoint etc. all contribute to Badput.
+program startup, data loading, portions of checkpointing, recovering from
+disruptions, wasted progress since the last checkpoint etc. all contribute to Badput.
+
+The ML Goodput Measurement library exposes Badput Breakdown. Further details of
+each bucket can be found [here](https://github.com/AI-Hypercomputer/ml-goodput-measurement?tab=readme-ov-file#badput-breakdown-details)
+
+### What is Rolling Window Goodput & Badput
+The ML Goodput Measurement library allows users to monitor goodput and badput
+breakdown metrics within specific, moving time windows. You can specify a list
+of rolling window interval sizes in seconds, and the library will asynchronously
+query and upload metrics calculated only within the context of those windows.
+This is useful for understanding workload performance over recent, specific
+durations (e.g., the last 24 hours).
 
-The ML Goodput Measurement library exposes Badput Breakdown. Further details of each bucket can be found [here](https://github.com/AI-Hypercomputer/ml-goodput-measurement?tab=readme-ov-file#badput-breakdown-details)
+If the workload's actual runtime timeline is shorter than a requested window size,
+the entire runtime timeline of the workload is used for the metrics computation.
-### What is Step Time Deviation +> **Note**: Both the standard (cumulative) and rolling window query APIs can be enabled simultaneously to get a complete picture of your workload's performance. + +### What are Ideal Step Time and Step Time Deviation Step Time Deviation is the metric that measures deviation of step time (in seconds) from ideal step time. It is the difference between the actual time @@ -33,8 +51,8 @@ The formula for step deviation is: Ideal step time is equal to the user-configured `ideal_step_time` if it is provided. If the user has not specified an ideal step time, then the ideal step -time is calculated as the average of the "normal" step times recorded for the -workload, where a "normal" step is defined as having a duration less than or +time is calculated as a weighted average of the "normal" step times recorded for +the workload, where a "normal" step is defined as having a duration less than or equal to `median + median absolute deviation * 3` of the sample space of step times. This computation requires at least 10 recorded steps. @@ -77,7 +95,7 @@ project, then do the following: Please use a unique workload name, unless you intend to monitor cumulative Goodput/Badput metrics of a previous workload along with your current workload. -### How to Monitor Goodput and Badput +### How to Monitor Cumulative Goodput Metrics To enable Goodput recording and monitoring on AXLearn, follow the example below. @@ -94,24 +112,22 @@ To enable Goodput recording and monitoring on AXLearn, follow the example below. --recorder_spec=upload_interval=30 \ ``` -### How to Monitor Step Time Deviation +### How to Monitor Rolling Window Goodput Metrics -AXLearn enables step time deviation monitoring by default. You can configure -the upload frequency by setting -`--recorder_spec=step_deviation_interval_seconds=30`. To disable step deviation -set `--recorder_spec=step_deviation_interval_seconds=-1`. +To enable rolling window metrics, set `enable_rolling_window_goodput_monitoring` to `True` +and provide a list of interval sizes for `rolling_window_size` in seconds: ```bash - axlearn gcp launch run --instance_type=tpu-v5litepod-16 \ +axlearn gcp launch run --instance_type=tpu-v5litepod-16 \ --bundler_type=artifactregistry --bundler_spec=image=tpu \ --bundler_spec=dockerfile=Dockerfile \ - --name= \ - -- python3 -m ...training-config... \ + -- python3 -m my_training_job \ --recorder_type=axlearn.cloud.gcp.measurement:goodput \ --recorder_spec=name= \ --recorder_spec=upload_dir=my-output-directory/summaries \ --recorder_spec=upload_interval=30 \ - --recorder_spec=step_deviation_interval_seconds=30 \ + --recorder_spec=enable_rolling_window_goodput_monitoring=True \ + --recorder_spec=rolling_window_size=86400,259200,432000 ``` ### Visualize on Tensorboard @@ -121,12 +137,16 @@ set `--recorder_spec=step_deviation_interval_seconds=-1`. ### Enabling Google Cloud Monitoring -AXLearn has an additional option of pushing goodput, badput and step time -deviation metrics to Google Cloud Monitoring. By default if goodput monitoring -is enabled, the data gets published to Google Cloud Monitoring. Set the variables -`enable_gcp_goodput_metrics` and `enable_gcp_step_deviation_metrics` to `False` in -`goodput_monitoring.GCPOptions` in `cloud/gcp/measurement.py` to disable goodput and step_deviation -uploads to GCM respectively. +By default, when Goodput monitoring is enabled via the recorder, AXLearn automatically pushes metrics to Google Cloud Monitoring. 
+
+- **Cumulative Metrics** are enabled by default when you specify the `recorder_type`.
+  To disable this, you would need to set `enable_gcp_goodput_metrics` to `False` in
+  `goodput_monitoring.GCPOptions` within the `cloud/gcp/measurement.py` file.
+- **Rolling Window Metrics** can be explicitly enabled by setting
+  `enable_rolling_window_goodput_monitoring` to `True` and providing window sizes
+  via `rolling_window_size`.
+
+You can enable either cumulative monitoring, rolling window monitoring, or both simultaneously.
 
 ```bash
 axlearn gcp launch run --instance_type=tpu-v5litepod-16 \
     --bundler_type=artifactregistry --bundler_spec=image=tpu \
     --bundler_spec=dockerfile=Dockerfile \
     --recorder_type=axlearn.cloud.gcp.measurement:goodput \
     --recorder_spec=name= \
     --recorder_spec=upload_dir=my-output-directory/summaries \
     --recorder_spec=upload_interval=30 \
-    --recorder_spec=step_deviation_interval_seconds=30 \
+    --recorder_spec=enable_rolling_window_goodput_monitoring=True \
+    --recorder_spec=rolling_window_size=86400,604800
 ```
 
 #### Visualization in Google Cloud Monitoring
@@ -159,3 +180,38 @@ To visualize the collected metrics within Google Cloud Monitoring:
    c. [**Performance:**](https://cloud.google.com/monitoring/api/metrics_gcp#:~:text=workload/performance)
    Represents the workload's performance metric, specifically step deviation in
    this context, measured by `compute.googleapis.com/workload/performance`.
+
+#### Google Cloud Monitoring Dashboard: Goodput Monitor
+
+Following are instructions for deploying a custom dashboard `goodput_dashboard.json`
+to your Google Cloud project's Monitoring console. This dashboard
+offers a comprehensive view of "Goodput" metrics, helping you monitor
+your workloads and set up custom alerts for "events" such as performance degradation.
+
+
+#### Deployment Steps
+
+Follow these steps to create a new custom dashboard using the provided JSON
+configuration:
+
+1. **Navigate to the Monitoring Console**: In your Google Cloud project,
+   go to the **Monitoring** section. From the left-hand navigation menu,
+   select **Dashboards**.
+
+2. **Create Custom Dashboard**: Click the **Create Custom Dashboard** button.
+
+3. **Use JSON Editor**: In the new dashboard interface, select the
+   **JSON editor** option.
+
+4. **Copy and Save Configuration**: Open the [goodput_dashboard.json](https://github.com/AI-Hypercomputer/ml-goodput-measurement/blob/main/ml_goodput_measurement/dashboards/goodput_dashboard.json) file.
+   Copy its entire content and paste it into the JSON editor. Once pasted,
+   click **Save**.
+
+
+Your "Goodput Monitor" dashboard should now be visible and operational within
+your custom dashboards list.
+
+> **_NOTE:_** This dashboard is intended to be a starting point for your
+> monitoring needs. We recommend customizing it to meet your specific requirements.
+> Please refer to the [Monitoring Dashboard documentation](https://cloud.google.com/monitoring/dashboards)
+> for further guidance and customization options.
diff --git a/pyproject.toml b/pyproject.toml
index 0a9bce9ca..111287d3a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -99,7 +99,7 @@ gcp = [
     "google-cloud-compute==1.19.2",  # Needed for region discovery for CloudBuild API access.
     "google-cloud-core==2.3.3",
     "google-cloud-build==3.24.1",
-    "ml-goodput-measurement==0.0.10",
+    "ml-goodput-measurement==0.0.13",
     "pika==1.3.2",  # used by event queue
     "pyOpenSSL>=22.1.0",  # compat with cryptography version.
     "tpu-info==0.2.0",  # For TPU monitoring from libtpu.
https://github.com/AI-Hypercomputer/cloud-accelerator-diagnostics/tree/main/tpu_info

From 799c4a9f70b3a1feab9b3fa1960494df49e650a7 Mon Sep 17 00:00:00 2001
From: Sam Stoelinga
Date: Sun, 3 Aug 2025 22:58:37 -0700
Subject: [PATCH 37/57] switch command to new goodput library

---
 test-orbax.sh | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test-orbax.sh b/test-orbax.sh
index be3b4b5ec..f0f105907 100755
--- a/test-orbax.sh
+++ b/test-orbax.sh
@@ -61,7 +61,7 @@ if [[ "$CONFIG" == *"orbaxem"* ]]; then
     --recorder_spec=name=goodput_${JOBSET_NAME} \
     --recorder_spec=upload_dir=${TRAINER_DIR}/summaries \
     --recorder_spec=upload_interval=30 \
-    --recorder_spec=step_deviation_interval_seconds=30 \
+    --recorder_spec=rolling_window_size=3600,7200,10800,86400 \
    --trace_at_steps=29,59,89,119,149,179,209,239,269,299,329,359,389,419,449,479,509,539,569,599,629,659,689,719
 
 else

From a138a9f06d3f76e41a90b8576d664dbf944c0f26 Mon Sep 17 00:00:00 2001
From: Sam Stoelinga
Date: Mon, 4 Aug 2025 11:11:29 -0700
Subject: [PATCH 38/57] print log every step

---
 axlearn/common/trainer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/axlearn/common/trainer.py b/axlearn/common/trainer.py
index 9cbb37262..4332d3a8e 100644
--- a/axlearn/common/trainer.py
+++ b/axlearn/common/trainer.py
@@ -1096,7 +1096,7 @@ def _run_step(
         # Run the compiled function.
         self._trainer_state, outputs = compiled_train_step_fn(self.trainer_state, input_batch)
 
-        if self.step % 100 == 0 or 0 <= self.step <= 5:
+        if self.step % 1 == 0 or 0 <= self.step <= 5:
             self._step_log(
                 "loss=%s aux=%s",
                 outputs["loss"],

From 5cc91ee783f6193f715b494057e54e43af9d4149 Mon Sep 17 00:00:00 2001
From: Sam Stoelinga
Date: Mon, 4 Aug 2025 11:25:56 -0700
Subject: [PATCH 39/57] force deletion correctly for terminating pods

---
 force_delete_terminating_pods.sh | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/force_delete_terminating_pods.sh b/force_delete_terminating_pods.sh
index dee311d8e..cc4665e50 100755
--- a/force_delete_terminating_pods.sh
+++ b/force_delete_terminating_pods.sh
@@ -30,12 +30,17 @@ while true; do
     pod_namespace=$(echo "$pod_json" | jq -r '.metadata.namespace')
     deletion_timestamp_str=$(echo "$pod_json" | jq -r '.metadata.deletionTimestamp')
 
+    # Sanitize the timestamp for macOS `date` by removing fractional seconds and the 'Z' suffix.
+    # This handles formats like "2024-01-01T12:34:56.123456Z" -> "2024-01-01T12:34:56"
+    sanitized_timestamp_str=$(echo "$deletion_timestamp_str" | sed -e 's/\.[0-9]*Z$/Z/' -e 's/Z$//')
+
     # Convert the RFC3339 timestamp to a Unix epoch timestamp
     # Works on both GNU and BSD (macOS) date commands.
     if date --version >/dev/null 2>&1; then # GNU date
-        deletion_ts=$(date -d "$deletion_timestamp_str" +%s)
+        deletion_ts=$(date -u -d "$sanitized_timestamp_str" +%s)
     else # BSD date
-        deletion_ts=$(date -jf "%Y-%m-%dT%H:%M:%SZ" "$deletion_timestamp_str" +%s)
+        # On macOS, use -u to interpret the time as UTC.
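+        # (The -jf format below matches the sanitized timestamp produced
+        # above, i.e. without fractional seconds or the trailing 'Z'.)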
+ deletion_ts=$(date -u -jf "%Y-%m-%dT%H:%M:%S" "$sanitized_timestamp_str" +%s) fi # Get the current time as a Unix epoch timestamp From ecbfa95b40842d33ebf09fce6e0adba202250292 Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Tue, 5 Aug 2025 13:39:56 -0700 Subject: [PATCH 40/57] switch to debug cluster gcs bucket --- test-orbax.sh | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test-orbax.sh b/test-orbax.sh index f0f105907..272609119 100755 --- a/test-orbax.sh +++ b/test-orbax.sh @@ -14,7 +14,9 @@ export MESH_SELECTOR=${MESH_SELECTOR:-"tpu-v6e-16"} # export CONFIG=${CONFIG:-"fuji-8B-v3-tiktoken-flash-orbax"} export CONFIG=${CONFIG:-"fuji-7B-v3-flash-orbaxem"} export PROJECT_ID=$(gcloud config get project) -export TRAINER_DIR=gs://tpu-prod-env-multipod-use4 +# export TRAINER_DIR=gs://tpu-prod-env-multipod-use4 +export TRAINER_DIR=gs://tpu-prod-env-one-vm-saw1-a + # Example for v6e-256 # MESH_SELECTOR=tpu-v6e-256-4 INSTANCE_TYPE=tpu-v6e-256 ./test-orbax.sh From eab4ed55e5997648bcb855530dcdd03e9eee3d1f Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Tue, 5 Aug 2025 13:48:07 -0700 Subject: [PATCH 41/57] bump orbax em fork --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 111287d3a..dfdc377d9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -156,7 +156,7 @@ mmau = [ orbax = [ "humanize==4.10.0", # "orbax-checkpoint==0.11.20", - "orbax-checkpoint @ git+https://github.com/samos123/orbax.git@v0.11.20-em#subdirectory=checkpoint" + "orbax-checkpoint @ git+https://github.com/samos123/orbax.git@v0.12.0#subdirectory=checkpoint" ] # Audio dependencies. audio = [ From 9318bcd87a58266c6d5c665f94d539c8e15055c9 Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Tue, 5 Aug 2025 16:05:37 -0700 Subject: [PATCH 42/57] turn off goodput logging since it needs more permissions --- test-orbax.sh | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/test-orbax.sh b/test-orbax.sh index 272609119..7f9e13744 100755 --- a/test-orbax.sh +++ b/test-orbax.sh @@ -34,7 +34,13 @@ export TRAINER_DIR=gs://tpu-prod-env-one-vm-saw1-a # --queue=multislice-queue \ # --priority_class=very-high \ # --trainer_dir=gs://tess-checkpoints-us-west1/${JOBSET_NAME}-nr-${NUM_REPLICAS}/ \ -# +# For goodput logging +# --recorder_type=axlearn.cloud.gcp.measurement:goodput \ +# --recorder_spec=name=goodput_${JOBSET_NAME} \ +# --recorder_spec=upload_dir=${TRAINER_DIR}/summaries \ +# --recorder_spec=upload_interval=30 \ +# --recorder_spec=rolling_window_size=3600,7200,10800,86400 \ +# --trace_at_steps=29,59,89,119,149,179,209,239,269,299,329,359,389,419,449,479,509,539,569,599,629,659,689,719 # Check if CONFIG ends with "orbaxem" if [[ "$CONFIG" == *"orbaxem"* ]]; then @@ -58,13 +64,7 @@ if [[ "$CONFIG" == *"orbaxem"* ]]; then --data_dir=gs://axlearn-public/tensorflow_datasets \ --jax_backend=tpu \ --mesh_selector=${MESH_SELECTOR} \ - --initialization_timeout=1200 \ - --recorder_type=axlearn.cloud.gcp.measurement:goodput \ - --recorder_spec=name=goodput_${JOBSET_NAME} \ - --recorder_spec=upload_dir=${TRAINER_DIR}/summaries \ - --recorder_spec=upload_interval=30 \ - --recorder_spec=rolling_window_size=3600,7200,10800,86400 \ - --trace_at_steps=29,59,89,119,149,179,209,239,269,299,329,359,389,419,449,479,509,539,569,599,629,659,689,719 + --initialization_timeout=1200 else echo "Running Orbax regular checkpointer or AXLearn native." 
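For context on the two event-recording styles this series switches between: the
context-manager API pairs START/END automatically via try/finally, while the plain
record() API relies on callers to emit both, which is how unpaired events (e.g. a
START_DATA_LOADING with no matching END on StopIteration, handled explicitly in the
revert below) can arise. A minimal sketch with hypothetical names, not the actual
AXLearn or ml-goodput-measurement API:

```python
import contextlib


class Recorder:
    """Toy recorder used only for this sketch."""

    def record(self, event: str) -> None:
        print(f"recorded {event}")

    @contextlib.contextmanager
    def record_event(self, event: str):
        # Context-manager style: END is emitted even if the body raises.
        self.record(f"START_{event}")
        try:
            yield
        finally:
            self.record(f"END_{event}")


recorder = Recorder()

# Paired style (what the revert restores): callers emit both ends explicitly.
recorder.record("START_DATA_LOADING")
recorder.record("END_DATA_LOADING")

# Context-manager style (what the revert removes):
with recorder.record_event("DATA_LOADING"):
    pass  # e.g. fetch the next input batch
```

The try/finally is what guarantees the END event on exceptions, mirroring the
`record_event` context manager that the revert below removes.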
From 2b11540ceb842591b95f6635d7716f4bf7511ae9 Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Tue, 5 Aug 2025 16:23:33 -0700 Subject: [PATCH 43/57] Revert "Integrate AXLearn with latest Goodput package" This reverts commit 0c2ce00e3598a2a7b7dea69e806f201b4af4f123. --- axlearn/cloud/gcp/measurement.py | 202 ++++------ axlearn/cloud/gcp/measurement_test.py | 508 +++++++++----------------- axlearn/common/launch_trainer.py | 18 +- axlearn/common/launch_trainer_main.py | 1 + axlearn/common/measurement.py | 55 +-- axlearn/common/measurement_test.py | 52 +-- axlearn/common/trainer.py | 340 ++++++++--------- docs/05-Goodput-Monitoring.md | 108 ++---- pyproject.toml | 2 +- 9 files changed, 502 insertions(+), 784 deletions(-) diff --git a/axlearn/cloud/gcp/measurement.py b/axlearn/cloud/gcp/measurement.py index 0eb226e6f..0d4ce0069 100644 --- a/axlearn/cloud/gcp/measurement.py +++ b/axlearn/cloud/gcp/measurement.py @@ -2,9 +2,6 @@ """Measurement utils for GCP. - For detailed documentation and advanced usage, please refer to: - axlearn/docs/05-Goodput-Monitoring.md - Example: # Enable Goodput when launching an AXLearn training job @@ -16,14 +13,10 @@ --recorder_spec=name=my-run-with-goodput \ --recorder_spec=upload_dir=my-output-directory/summaries \ --recorder_spec=upload_interval=30 \ - --recorder_spec=rolling_window_size=86400,604800 + --recorder_spec=step_deviation_interval_seconds=30 """ -import contextlib -import os -from typing import Optional, Sequence - import jax from absl import flags, logging from ml_goodput_measurement import goodput @@ -45,19 +38,13 @@ class Config(measurement.Recorder.Config): Attributes: upload_dir: Directory to store metrics for the monitor. upload_interval: Time interval (seconds) for monitoring uploads. - See "How to Monitor Cumulative Goodput Metrics" in - docs/05-Goodput-Monitoring.md for more details. - rolling_window_size: A sequence of integers defining the rolling window sizes in - seconds. - See "How to Monitor Rolling Window Goodput Metrics" in - docs/05-Goodput-Monitoring.md for more details. - jax_backend: Jax backend type to infer Pathways environment. + step_deviation_interval_seconds: Time interval (seconds) for step deviation metrics + uploads. -1 to disable step deviation uploads. """ upload_dir: Required[str] = REQUIRED upload_interval: Required[int] = REQUIRED - rolling_window_size: Sequence[int] = [] - jax_backend: Optional[str] = None + step_deviation_interval_seconds: int = 30 # Default to 30 seconds @classmethod def from_flags(cls, fv: flags.FlagValues) -> "GoodputRecorder": @@ -66,78 +53,68 @@ def from_flags(cls, fv: flags.FlagValues) -> "GoodputRecorder": `fv.recorder_spec` will be interpreted as a list of `key=value` pairs; config names corresponding to keys will be set to the corresponding values. A GoodputRecorder can additionally take in following Tensorboard configs in the recorder_spec: - - upload_dir: The directory to write Tensorboard data to. - - upload_interval: The time interval in seconds at which to query and upload data - to Tensorboard. - - rolling_window_size: Comma-separated list of integers representing rolling window - sizes in seconds. - - jax_backend: The type of jax backend. + - upload_dir: The directory to write Tensorboard data to. + - upload_interval: The time interval in seconds at which to query and upload data + to Tensorboard. + - step_deviation_interval_seconds: Time interval (seconds) for step deviation metrics + uploads. Set to less than or equal to 0 to disable step deviation uploads. 
""" cfg: measurement.Recorder.Config = cls.default_config() - parsed_flags = parse_kv_flags(fv.recorder_spec, delimiter="=") - if "upload_interval" in parsed_flags: - parsed_flags["upload_interval"] = int(parsed_flags["upload_interval"]) - if "rolling_window_size" in parsed_flags and isinstance( - parsed_flags["rolling_window_size"], str - ): - parsed_flags["rolling_window_size"] = [ - int(x) for x in parsed_flags["rolling_window_size"].split(",") - ] - return maybe_set_config(cfg, **parsed_flags).instantiate() + cfg = maybe_set_config(cfg, **parse_kv_flags(fv.recorder_spec, delimiter="=")) + return cfg.instantiate() def __init__(self, cfg): super().__init__(cfg) - self._recorder: Optional[goodput.GoodputRecorder] = None - self._monitor: Optional[goodput_monitoring.GoodputMonitor] = None - self._rolling_window_monitor: Optional[goodput_monitoring.GoodputMonitor] = None - self._job_name = cfg.name - self._logger_name = f"goodput_logger_{cfg.name}" - - @contextlib.contextmanager - def record_event(self, event: measurement.Event, *args, **kwargs): - """Records a goodput event using a context manager.""" - # Lazily instantiate the recorder if it hasn't been already. + cfg: GoodputRecorder.Config = self.config + self._recorder = None + self._monitor = None + + def record(self, event: measurement.Event, *args, **kwargs): + # Lazily instantiate the recorder. This avoids invoking jax before setup is complete. if self._recorder is None: - if jax.process_index() == 0: - logging.info("Lazily instantiating goodput recorder.") + cfg: GoodputRecorder.Config = self.config self._recorder = goodput.GoodputRecorder( - job_name=self._job_name, - logger_name=self._logger_name, + job_name=cfg.name, + logger_name=f"goodput_logger_{cfg.name}", logging_enabled=(jax.process_index() == 0), ) - start_method_name = f"record_{event.value}_start_time" - end_method_name = f"record_{event.value}_end_time" - - record_event_start = getattr(self._recorder, start_method_name, None) - record_event_end = getattr(self._recorder, end_method_name, None) - - try: - if record_event_start: - record_event_start(*args, **kwargs) - except RuntimeError as e: - logging.warning( - "Failed to record start of event %s. 
Error: %s", event.value, e, exc_info=True + if event == measurement.Event.START_JOB: + self._recorder.record_job_start_time(*args, **kwargs) + elif event == measurement.Event.END_JOB: + self._recorder.record_job_end_time(*args, **kwargs) + elif event == measurement.Event.START_STEP: + self._recorder.record_step_start_time(*args, **kwargs) + elif event == measurement.Event.START_ACCELERATOR_INIT: + self._recorder.record_tpu_init_start_time(*args, **kwargs) + elif event == measurement.Event.END_ACCELERATOR_INIT: + self._recorder.record_tpu_init_end_time(*args, **kwargs) + elif event == measurement.Event.START_TRAINING_PREPARATION: + self._recorder.record_training_preparation_start_time(*args, **kwargs) + elif event == measurement.Event.END_TRAINING_PREPARATION: + self._recorder.record_training_preparation_end_time(*args, **kwargs) + elif event == measurement.Event.START_DATA_LOADING: + self._recorder.record_data_loading_start_time(*args, **kwargs) + elif event == measurement.Event.END_DATA_LOADING: + self._recorder.record_data_loading_end_time(*args, **kwargs) + elif event == measurement.Event.START_CUSTOM_BADPUT_EVENT: + self._recorder.record_custom_badput_event_start_time(*args, **kwargs) + elif event == measurement.Event.END_CUSTOM_BADPUT_EVENT: + self._recorder.record_custom_badput_event_end_time(*args, **kwargs) + else: + logging.log_first_n( + logging.WARNING, + "Ignoring unknown event %s", + 1, + event, ) - try: - yield - finally: - try: - if record_event_end: - record_event_end(*args, **kwargs) - except RuntimeError as e: - logging.warning( - "Failed to record end of event %s. Error: %s", event.value, e, exc_info=True - ) - - @contextlib.contextmanager - def _maybe_monitor_goodput(self, *args, **kwargs): - """Monitor cumulative goodput if enabled. + def start_monitoring(self, *args, **kwargs): + """Starts Monitoring of Goodput. Instantiate ml-goodput-measurement's GoodputMonitor to asynchronously calculate - Goodput, Badput, Step & Disruption Information at the upload_interval to the - specified TensorBoard directory and Google Cloud Monitoring. + Goodput and Badput at the upload_interval and upload to the specified TensorBoard + directory. Note: This function requires initialization of distributed JAX before it is called. If there are internal GCP errors from querying and uploading data, these will be logged without affecting the workload. GoodputMonitor logs will provide further @@ -146,68 +123,33 @@ def _maybe_monitor_goodput(self, *args, **kwargs): Default behavior is to push metrics to Google Cloud Monitoring. 
This behavior can be overridden by configuring `goodput_monitoring.GCPOptions` """ - if jax.process_index() != 0: - yield - return - try: + cfg: GoodputRecorder.Config = self.config + include_step_deviation = True + if jax.process_index() == 0: if self._monitor is None: + if int(cfg.step_deviation_interval_seconds) <= 0: + include_step_deviation = False + + gcp_options = goodput_monitoring.GCPOptions( + enable_gcp_goodput_metrics=True, + enable_gcp_step_deviation_metrics=include_step_deviation, + ) self._monitor = goodput_monitoring.GoodputMonitor( - job_name=self._job_name, - logger_name=self._logger_name, - tensorboard_dir=self.config.upload_dir, - upload_interval=self.config.upload_interval, + job_name=cfg.name, + logger_name=f"goodput_logger_{cfg.name}", + tensorboard_dir=cfg.upload_dir, + upload_interval=int(cfg.upload_interval), monitoring_enabled=True, - pathway_enabled=self.config.jax_backend == "proxy", include_badput_breakdown=True, + include_step_deviation=include_step_deviation, + step_deviation_interval_seconds=int(cfg.step_deviation_interval_seconds), + gcp_options=gcp_options, ) self._monitor.start_goodput_uploader(*args, **kwargs) logging.info("Started Goodput upload to Tensorboard & GCM in the background!") - yield - finally: - if self._monitor: - self._monitor.stop_goodput_uploader() - logging.info("Flushed final metrics and safe exited from Goodput monitoring.") - - @contextlib.contextmanager - def _maybe_monitor_rolling_window_goodput(self): - """Monitor rolling window goodput if enabled.""" - if not self.config.rolling_window_size or jax.process_index() != 0: - yield - return - try: - if self._rolling_window_monitor is None: - rolling_window_tensorboard_dir = os.path.join( - self.config.upload_dir, f"rolling_window_{self.config.name}" - ) - self._rolling_window_monitor = goodput_monitoring.GoodputMonitor( - job_name=self._job_name, - logger_name=self._logger_name, - tensorboard_dir=rolling_window_tensorboard_dir, - upload_interval=self.config.upload_interval, - monitoring_enabled=True, - pathway_enabled=self.config.jax_backend == "proxy", - include_badput_breakdown=True, - ) - self._rolling_window_monitor.start_rolling_window_goodput_uploader( - self.config.rolling_window_size - ) - logging.info("Started Rolling Window Goodput monitoring in the background!") - yield - finally: - if self._rolling_window_monitor: - self._rolling_window_monitor.stop_rolling_window_goodput_uploader() + if include_step_deviation: + self._monitor.start_step_deviation_uploader(*args, **kwargs) logging.info( - "Flushed final metrics and safe exited from Rolling Window Goodput monitoring." + "Started Step Deviation upload to Tensorboard & GCM in the background!" 
) - - def maybe_monitor_all_goodput(self): - goodput_monitor_manager = self._maybe_monitor_goodput() - rolling_goodput_monitor_manager = self._maybe_monitor_rolling_window_goodput() - - @contextlib.contextmanager - def monitor_goodput(): - with goodput_monitor_manager, rolling_goodput_monitor_manager: - yield - - return monitor_goodput() diff --git a/axlearn/cloud/gcp/measurement_test.py b/axlearn/cloud/gcp/measurement_test.py index e944a262c..e14fc16c4 100644 --- a/axlearn/cloud/gcp/measurement_test.py +++ b/axlearn/cloud/gcp/measurement_test.py @@ -3,373 +3,191 @@ """Tests measurement utils for GCP.""" # pylint: disable=protected-access +import contextlib from unittest import mock -from absl import flags, logging +from absl import flags from absl.testing import parameterized from axlearn.cloud.gcp.measurement import GoodputRecorder from axlearn.common import measurement -from axlearn.common.config import RequiredFieldMissingError class GoodputRecorderTest(parameterized.TestCase): """Tests GoodputRecorder.""" @parameterized.parameters( - dict( - recorder_spec=[ + (None,), (["name=test-name", "upload_dir=/test/path/to/upload", "upload_interval=15"],) + ) + def test_from_flags(self, spec): + fv = flags.FlagValues() + measurement.define_flags(flag_values=fv) + if spec is not None: + fv.set_default("recorder_spec", spec) + fv.mark_as_parsed() + + if spec is None: + ctx = self.assertRaisesRegex(ValueError, "name") + else: + ctx = contextlib.nullcontext() + + with ctx: + recorder = GoodputRecorder.from_flags(fv) + # Recorder is not instantiated until first event. + self.assertIsNone(recorder._recorder) + + def test_record_and_monitor(self): + fv = flags.FlagValues() + measurement.define_flags(flag_values=fv) + fv.set_default( + "recorder_spec", + ["name=test-name", "upload_dir=/test/path/to/upload", "upload_interval=15"], + ) + fv.mark_as_parsed() + + recorder = GoodputRecorder.from_flags(fv) + recorder._recorder = mock.MagicMock() + recorder.record(measurement.Event.START_JOB) + self.assertTrue(recorder._recorder.record_job_start_time.called) + + def test_start_goodput_monitoring(self): + fv = flags.FlagValues() + measurement.define_flags(flag_values=fv) + fv.set_default( + "recorder_spec", + [ "name=test-name", - "upload_dir=/test/path", + "upload_dir=/test/path/to/upload", "upload_interval=15", + "step_deviation_interval_seconds=-1", ], - expected_rolling_window_size=[], - expected_jax_backend=None, - ), - dict( - recorder_spec=[ + ) + fv.mark_as_parsed() + + recorder = GoodputRecorder.from_flags(fv) + self.assertIsNone(recorder._monitor) # Ensure _monitor is initially None + + with mock.patch("ml_goodput_measurement.monitoring.GoodputMonitor") as mock_goodput_monitor: + with mock.patch("ml_goodput_measurement.monitoring.GCPOptions") as mock_gcp_options: + mock_monitor_instance = mock_goodput_monitor.return_value + recorder.start_monitoring() + mock_gcp_options.assert_called_once_with( + enable_gcp_goodput_metrics=True, + enable_gcp_step_deviation_metrics=False, + ) + mock_gcp_options_instance = mock_gcp_options.return_value + + # Check that GoodputMonitor was instantiated + mock_goodput_monitor.assert_called_once_with( + job_name="test-name", + logger_name="goodput_logger_test-name", + tensorboard_dir="/test/path/to/upload", + upload_interval=15, + monitoring_enabled=True, + include_badput_breakdown=True, + include_step_deviation=False, + step_deviation_interval_seconds=-1, + gcp_options=mock_gcp_options_instance, + ) + + # Ensure that start_goodput_uploader is called on the monitor 
instance + mock_monitor_instance.start_goodput_uploader.assert_called_once() + self.assertIsNotNone(recorder._monitor) + + def test_start_goodput_and_step_deviation_monitoring(self): + fv = flags.FlagValues() + measurement.define_flags(flag_values=fv) + fv.set_default( + "recorder_spec", + [ "name=test-name", - "upload_dir=/test/path", + "upload_dir=/test/path/to/upload", "upload_interval=15", - "rolling_window_size=1,2,3", - "jax_backend=proxy", + "step_deviation_interval_seconds=30", ], - expected_rolling_window_size=[1, 2, 3], - expected_jax_backend="proxy", - ), - ) - def test_from_flags( - self, - recorder_spec, - expected_rolling_window_size, - expected_jax_backend, - ): - """Tests that flags are correctly parsed into the config.""" - mock_fv = mock.MagicMock(spec=flags.FlagValues) - mock_fv.recorder_spec = recorder_spec - mock_fv.jax_backend = "tpu" - - recorder = GoodputRecorder.from_flags(mock_fv) - - self.assertEqual("test-name", recorder.config.name) - self.assertEqual("/test/path", recorder.config.upload_dir) - self.assertEqual(15, recorder.config.upload_interval) - self.assertEqual(expected_rolling_window_size, recorder.config.rolling_window_size) - self.assertEqual(expected_jax_backend, recorder.config.jax_backend) - - def test_from_flags_missing_required(self): - """Tests that missing required flags raise an error.""" - mock_fv = mock.MagicMock(spec=flags.FlagValues) - mock_fv.recorder_spec = ["name=test-name"] # Missing upload_dir/interval - mock_fv.jax_backend = "tpu" - with self.assertRaisesRegex(RequiredFieldMissingError, "upload_dir"): - GoodputRecorder.from_flags(mock_fv) - - @parameterized.parameters( - dict( - event=measurement.Event.JOB, - expected_start="record_job_start_time", - expected_end="record_job_end_time", - args=(), - kwargs={}, - expect_end_call=True, - ), - dict( - event=measurement.Event.STEP, - expected_start="record_step_start_time", - expected_end=None, - args=(123,), - kwargs={}, - expect_end_call=False, - ), - dict( - event=measurement.Event.ACCELERATOR_INIT, - expected_start="record_tpu_init_start_time", - expected_end="record_tpu_init_end_time", - args=(), - kwargs={}, - expect_end_call=True, - ), - dict( - event=measurement.Event.TRAINING_PREPARATION, - expected_start="record_training_preparation_start_time", - expected_end="record_training_preparation_end_time", - args=(), - kwargs={}, - expect_end_call=True, - ), - dict( - event=measurement.Event.DATA_LOADING, - expected_start="record_data_loading_start_time", - expected_end="record_data_loading_end_time", - args=(), - kwargs={}, - expect_end_call=True, - ), - dict( - event=measurement.Event.CUSTOM_BADPUT_EVENT, - expected_start="record_custom_badput_event_start_time", - expected_end="record_custom_badput_event_end_time", - args=(), - kwargs={"custom_badput_event_type": "TEST_TYPE"}, - expect_end_call=True, - ), - ) - @mock.patch("jax.process_index", return_value=0) - def test_record_event_context_manager_success( - self, _, event, expected_start, expected_end, args, kwargs, expect_end_call - ): - """Tests that record_event calls correct start and end methods with args and kwargs.""" - cfg = GoodputRecorder.default_config().set( - name="test", - upload_dir="/tmp/test", - upload_interval=1, - ) - recorder = GoodputRecorder(cfg) - - with mock.patch("ml_goodput_measurement.goodput.GoodputRecorder") as mock_recorder_cls: - mock_instance = mock_recorder_cls.return_value - - start_mock = mock.MagicMock() - setattr(mock_instance, expected_start, start_mock) - if expect_end_call and expected_end: - 
end_mock = mock.MagicMock() - setattr(mock_instance, expected_end, end_mock) - - with recorder.record_event(event, *args, **kwargs): - pass - - mock_recorder_cls.assert_called_once() - start_mock.assert_called_once_with(*args, **kwargs) - if expect_end_call and expected_end: - end_mock.assert_called_once_with(*args, **kwargs) - - def test_record_event_context_manager_handles_runtime_error(self): - cfg = GoodputRecorder.default_config().set( - name="test", - upload_dir="/tmp/test", - upload_interval=1, - ) - recorder = GoodputRecorder(cfg) - - with mock.patch("jax.process_index", return_value=0): - with mock.patch( - "ml_goodput_measurement.goodput.GoodputRecorder" - ) as mock_recorder_cls, mock.patch.object(logging, "warning") as mock_warning: - mock_instance = mock_recorder_cls.return_value - - def raise_runtime_error(*args, **kwargs): - raise RuntimeError("mocked error") - - mock_instance.record_job_start_time.side_effect = raise_runtime_error - mock_instance.record_job_end_time.side_effect = raise_runtime_error - # Should not crash here. - with recorder.record_event(measurement.Event.JOB): - pass - - # Assert warnings were logged for start and end failures - assert mock_warning.call_count == 2 - start_call = mock_warning.call_args_list[0] - end_call = mock_warning.call_args_list[1] - - assert "Failed to record" in start_call.args[0] - assert "Failed to record" in end_call.args[0] - - @parameterized.parameters( - dict(is_pathways_job=False, mock_jax_backend="tpu"), - dict(is_pathways_job=True, mock_jax_backend="proxy"), - dict(is_pathways_job=False, mock_jax_backend=None), - ) - @mock.patch("jax.process_index", return_value=0) - def test_maybe_monitor_goodput(self, _, is_pathways_job, mock_jax_backend): - """Tests the _maybe_monitor_goodput context manager.""" - cfg = GoodputRecorder.default_config().set( - name="test-monitor", - upload_dir="/test", - upload_interval=30, - jax_backend=mock_jax_backend, ) - recorder = GoodputRecorder(cfg) - - with mock.patch("ml_goodput_measurement.monitoring.GoodputMonitor") as mock_monitor_cls: - mock_monitor_instance = mock_monitor_cls.return_value - with recorder._maybe_monitor_goodput(): - pass - - # Verify that GoodputMonitor was instantiated with the correct parameters. 
- mock_monitor_cls.assert_called_once_with( - job_name="test-monitor", - logger_name="goodput_logger_test-monitor", - tensorboard_dir="/test", - upload_interval=30, - monitoring_enabled=True, - pathway_enabled=is_pathways_job, - include_badput_breakdown=True, - ) - mock_monitor_instance.start_goodput_uploader.assert_called_once() - mock_monitor_instance.stop_goodput_uploader.assert_called_once() - - @parameterized.parameters( - dict( - is_rolling_window_enabled=True, - rolling_window_size=[10, 20], - is_pathways_job=False, - mock_jax_backend="tpu", - ), - dict( - is_rolling_window_enabled=False, - rolling_window_size=[], - is_pathways_job=False, - mock_jax_backend="tpu", - ), - dict( - is_rolling_window_enabled=True, - rolling_window_size=[50], - is_pathways_job=True, - mock_jax_backend="proxy", - ), - ) - @mock.patch("jax.process_index", return_value=0) - def test_maybe_monitor_rolling_window( - self, - mock_process_index, - is_rolling_window_enabled, - rolling_window_size, - is_pathways_job, - mock_jax_backend, - ): # pylint: disable=unused-argument - """Tests the rolling window monitoring.""" - cfg = GoodputRecorder.default_config().set( - name="test-rolling", - upload_dir="/test", - upload_interval=30, - rolling_window_size=rolling_window_size, - jax_backend=mock_jax_backend, - ) - recorder = GoodputRecorder(cfg) - - with mock.patch("ml_goodput_measurement.monitoring.GoodputMonitor") as mock_monitor_cls: - mock_monitor_instance = mock_monitor_cls.return_value - if not is_rolling_window_enabled: - with recorder._maybe_monitor_rolling_window_goodput(): - pass - mock_monitor_cls.assert_not_called() - return - with recorder._maybe_monitor_rolling_window_goodput(): - pass - - mock_monitor_cls.assert_called_once_with( - job_name="test-rolling", - logger_name="goodput_logger_test-rolling", - tensorboard_dir="/test/rolling_window_test-rolling", - upload_interval=30, - monitoring_enabled=True, - pathway_enabled=is_pathways_job, - include_badput_breakdown=True, - ) - - mock_monitor_instance.start_rolling_window_goodput_uploader.assert_called_with( - rolling_window_size - ) - mock_monitor_instance.stop_rolling_window_goodput_uploader.assert_called_once() + fv.mark_as_parsed() + + recorder = GoodputRecorder.from_flags(fv) + self.assertIsNone(recorder._monitor) # Ensure _monitor is initially None + + with mock.patch("ml_goodput_measurement.monitoring.GoodputMonitor") as mock_goodput_monitor: + with mock.patch("ml_goodput_measurement.monitoring.GCPOptions") as mock_gcp_options: + mock_monitor_instance = mock_goodput_monitor.return_value + recorder.start_monitoring() + mock_gcp_options.assert_called_once_with( + enable_gcp_goodput_metrics=True, + enable_gcp_step_deviation_metrics=True, + ) + mock_gcp_options_instance = mock_gcp_options.return_value + + # Check that GoodputMonitor was instantiated + mock_goodput_monitor.assert_called_once_with( + job_name="test-name", + logger_name="goodput_logger_test-name", + tensorboard_dir="/test/path/to/upload", + upload_interval=15, + monitoring_enabled=True, + include_badput_breakdown=True, + include_step_deviation=True, + step_deviation_interval_seconds=30, + gcp_options=mock_gcp_options_instance, + ) - @mock.patch("jax.process_index", return_value=1) - def test_non_zero_process_index_skips_monitoring( - self, mock_process_index - ): # pylint: disable=unused-argument - """Tests that monitoring is skipped on non-zero process indices.""" - cfg = GoodputRecorder.default_config().set( - name="test", upload_dir="/test", upload_interval=30 + # Ensure that 
start_goodput_uploader and start_step_deviation_uploader is called on + # the monitor instance + mock_monitor_instance.start_goodput_uploader.assert_called_once() + mock_monitor_instance.start_step_deviation_uploader.assert_called_once() + self.assertIsNotNone(recorder._monitor) + + def test_missing_required_flags(self): + fv = flags.FlagValues() + measurement.define_flags(flag_values=fv) + # Missing 'upload_dir' and 'upload_interval' from recorder_spec + fv.set_default("recorder_spec", ["name=test-name"]) # Incomplete config + fv.mark_as_parsed() + + # Expecting ValueError since 'upload_dir' and 'upload_interval' are required + with self.assertRaises(ValueError): + GoodputRecorder.from_flags(fv) + + def test_monitoring_initialization_failure(self): + fv = flags.FlagValues() + measurement.define_flags(flag_values=fv) + fv.set_default( + "recorder_spec", + ["name=test-name", "upload_dir=/test/path/to/upload", "upload_interval=15"], ) - recorder = GoodputRecorder(cfg) - - with mock.patch("ml_goodput_measurement.monitoring.GoodputMonitor") as mock_monitor_cls: - # Test cumulative goodput monitoring. - with recorder._maybe_monitor_goodput(): - pass - mock_monitor_cls.assert_not_called() - - cfg_rolling = GoodputRecorder.default_config().set( - name="test-rolling-skip", - upload_dir="/test", - upload_interval=30, - rolling_window_size=[10, 20], - ) - recorder_rolling = GoodputRecorder(cfg_rolling) - with recorder_rolling._maybe_monitor_rolling_window_goodput(): - pass - mock_monitor_cls.assert_not_called() - - @parameterized.parameters( - dict( - rolling_window_size=[5, 10], - jax_backend="tpu", - expected_monitor_calls=2, # Cumulative & Rolling Window - expect_rolling=True, - expect_cumulative=True, - ), - dict( - rolling_window_size=[], - jax_backend="tpu", - expected_monitor_calls=1, # Cumulative only - expect_rolling=False, - expect_cumulative=True, - ), - dict( - rolling_window_size=[5, 10], - jax_backend=None, # Disables Pathways - expected_monitor_calls=2, - expect_rolling=True, - expect_cumulative=True, - ), - dict( - rolling_window_size=[], - jax_backend=None, - expected_monitor_calls=1, - expect_rolling=False, - expect_cumulative=True, - ), - ) - @mock.patch("jax.process_index", return_value=0) - def test_maybe_monitor_all_goodput( - self, - _, - rolling_window_size, - jax_backend, - expected_monitor_calls, - expect_rolling, - expect_cumulative, - ): - """Tests all goodput monitoring with various configs.""" - cfg = GoodputRecorder.default_config().set( - name="test-all", - upload_dir="/test", - upload_interval=30, - rolling_window_size=rolling_window_size, - jax_backend=jax_backend, + fv.mark_as_parsed() + + recorder = GoodputRecorder.from_flags(fv) + self.assertIsNone(recorder._monitor) + + # Mock a failure in initializing the GoodputMonitor + with mock.patch( + "ml_goodput_measurement.monitoring.GoodputMonitor", + side_effect=Exception("Failed to initialize GoodputMonitor"), + ): + with self.assertRaises(Exception): + recorder.start_monitoring() + self.assertIsNone(recorder._monitor) + + def test_non_zero_process_index(self): + fv = flags.FlagValues() + measurement.define_flags(flag_values=fv) + fv.set_default( + "recorder_spec", + ["name=test-name", "upload_dir=/test/path/to/upload", "upload_interval=15"], ) - recorder = GoodputRecorder(cfg) - - with mock.patch("ml_goodput_measurement.monitoring.GoodputMonitor") as mock_monitor_cls: - mock_monitor_instance = mock_monitor_cls.return_value + fv.mark_as_parsed() - with recorder.maybe_monitor_all_goodput(): - pass + recorder = 
GoodputRecorder.from_flags(fv) + self.assertIsNone(recorder._monitor) - self.assertEqual(mock_monitor_cls.call_count, expected_monitor_calls) + with mock.patch("jax.process_index") as mock_process_index: + mock_process_index.return_value = 1 # Simulate a non-zero process index - if expect_cumulative: - mock_monitor_instance.start_goodput_uploader.assert_called_once() - mock_monitor_instance.stop_goodput_uploader.assert_called_once() - else: - mock_monitor_instance.start_goodput_uploader.assert_not_called() - mock_monitor_instance.stop_goodput_uploader.assert_not_called() - - if expect_rolling: - mock_monitor_instance.start_rolling_window_goodput_uploader.assert_called_once_with( - rolling_window_size - ) - mock_monitor_instance.stop_rolling_window_goodput_uploader.assert_called_once() - else: - mock_monitor_instance.start_rolling_window_goodput_uploader.assert_not_called() - mock_monitor_instance.stop_rolling_window_goodput_uploader.assert_not_called() + try: + recorder.start_monitoring() + except AttributeError: + self.fail("AttributeError was raised unexpectedly.") diff --git a/axlearn/common/launch_trainer.py b/axlearn/common/launch_trainer.py index 7470ad66c..bba28533e 100644 --- a/axlearn/common/launch_trainer.py +++ b/axlearn/common/launch_trainer.py @@ -2,7 +2,6 @@ """Utilities to launch a trainer.""" -import contextlib import json import os from typing import Any, Optional @@ -129,8 +128,8 @@ def get_trainer_config( return trainer_config -def _run_trainer_impl(trainer_config: SpmdTrainer.Config) -> Any: - """Instantiates and runs the trainer.""" +def run_trainer(trainer_config: SpmdTrainer.Config) -> Any: + measurement.record_event(measurement.Event.START_JOB) trainer_config_debug_string = trainer_config.debug_string() logging.info("Trainer config:\n%s", trainer_config_debug_string) if jax.process_index() == 0: @@ -150,13 +149,6 @@ def _run_trainer_impl(trainer_config: SpmdTrainer.Config) -> Any: trainer: SpmdTrainer = trainer_config.instantiate(parent=None) prng_key = jax.random.PRNGKey(seed=FLAGS.trainer_prng_seed) - return trainer.run(prng_key) - - -def run_trainer(trainer_config: SpmdTrainer.Config) -> Any: - recorder = measurement.global_recorder - job_events_manager = ( - recorder.record_event(measurement.Event.JOB) if recorder else contextlib.nullcontext() - ) - with job_events_manager: - return _run_trainer_impl(trainer_config) + output = trainer.run(prng_key) + measurement.record_event(measurement.Event.END_JOB) + return output diff --git a/axlearn/common/launch_trainer_main.py b/axlearn/common/launch_trainer_main.py index 8d170a950..2f617b4cd 100644 --- a/axlearn/common/launch_trainer_main.py +++ b/axlearn/common/launch_trainer_main.py @@ -13,6 +13,7 @@ def main(_): launch.setup() trainer_config = launch_trainer.get_trainer_config() trainer_config.set(recorder=config_for_function(lambda: measurement.global_recorder)) + measurement.start_monitoring() launch_trainer.run_trainer(trainer_config) diff --git a/axlearn/common/measurement.py b/axlearn/common/measurement.py index 1d2a9dea7..b0a40a85f 100644 --- a/axlearn/common/measurement.py +++ b/axlearn/common/measurement.py @@ -2,7 +2,6 @@ """A library to measure e2e metrics like goodput.""" -import contextlib import enum import importlib from typing import Optional, TypeVar @@ -16,20 +15,30 @@ class Event(enum.Enum): """Event to be recorded. Attributes: - JOB: Start and end of the job. - STEP: Start of a training step. Should be recorded with `step` as a positional arg. 
- ACCELERATOR_INIT: Start and end of accelerator mesh initialization. - TRAINING_PREPARATION: Start and end of training preparation. - DATA_LOADING: Start and end of data loading. - CUSTOM_BADPUT_EVENT: Start and end of custom badput events. + START_JOB: Start of job. + END_JOB: End of job. + START_STEP: Start of a training step. Should be recorded with `step` as a positional arg. + START_ACCELERATOR_INIT: Start of accelerator mesh initialization. + END_ACCELERATOR_INIT: End of accelerator mesh initialization. + START_TRAINING_PREPARATION: Start of training preparation. + END_TRAINING_PREPARATION: End of training preparation. + START_DATA_LOADING: Start of data loading. + END_DATA_LOADING: End of data loading. + START_CUSTOM_BADPUT_EVENT: Start of custom badput event. + END_CUSTOM_BADPUT_EVENT: End of custom badput event. """ - JOB = "job" - STEP = "step" - ACCELERATOR_INIT = "tpu_init" - TRAINING_PREPARATION = "training_preparation" - DATA_LOADING = "data_loading" - CUSTOM_BADPUT_EVENT = "custom_badput_event" + START_JOB = "START_JOB" + END_JOB = "END_JOB" + START_STEP = "START_STEP" + START_ACCELERATOR_INIT = "START_ACCELERATOR_INIT" + END_ACCELERATOR_INIT = "END_ACCELERATOR_INIT" + START_TRAINING_PREPARATION = "START_TRAINING_PREPARATION" + END_TRAINING_PREPARATION = "END_TRAINING_PREPARATION" + START_DATA_LOADING = "START_DATA_LOADING" + END_DATA_LOADING = "END_DATA_LOADING" + START_CUSTOM_BADPUT_EVENT = "START_CUSTOM_BADPUT_EVENT" + END_CUSTOM_BADPUT_EVENT = "END_CUSTOM_BADPUT_EVENT" class Recorder(Configurable): @@ -50,15 +59,9 @@ def from_flags(cls, fv: Optional[flags.FlagValues]) -> "Recorder": """Converts flags to a recorder.""" raise NotImplementedError(cls) - @contextlib.contextmanager - def record_event(self, event: Event, *args, **kwargs): - """A context manager to record the start and end of an event.""" - # pylint: disable=unnecessary-pass - # pylint: disable=unused-argument - try: - yield - finally: - pass + def record(self, event: Event, *args, **kwargs): + """Records an event with the given name.""" + raise NotImplementedError(type(self)) def start_monitoring(self, **kwargs): """Starts computing and uploading metrics at some configured interval in the background.""" @@ -131,6 +134,14 @@ def initialize(fv: flags.FlagValues): ) +def record_event(event: Event): + """Records a global event.""" + if global_recorder is None: + logging.log_first_n(logging.INFO, "No recorder configured, ignoring events.", 1) + else: + global_recorder.record(event) + + def start_monitoring(): """Begins monitoring events as per global monitor functionality.""" if global_recorder is None: diff --git a/axlearn/common/measurement_test.py b/axlearn/common/measurement_test.py index d36605f29..c9043f20b 100644 --- a/axlearn/common/measurement_test.py +++ b/axlearn/common/measurement_test.py @@ -3,30 +3,24 @@ """Tests measurement utils.""" # pylint: disable=protected-access -import contextlib from unittest import mock from absl import flags from absl.testing import parameterized from axlearn.common import measurement -from axlearn.experiments.testdata.axlearn_common_measurement_test.dummy_recorder import ( - DummyRecorder as RealDummyRecorder, -) class UtilsTest(parameterized.TestCase): """Tests utils.""" def setUp(self): - super().setUp() self._orig_recorder = measurement.global_recorder - self._orig_recorders = measurement._recorders.copy() + self._orig_recorders = measurement._recorders measurement.global_recorder = None measurement._recorders = {} def tearDown(self): - super().tearDown() 
measurement.global_recorder = self._orig_recorder measurement._recorders = self._orig_recorders @@ -39,25 +33,32 @@ class DummyRecorder(measurement.Recorder): self.assertEqual(DummyRecorder, measurement._recorders.get("test")) + # Registering twice should fail. with self.assertRaisesRegex(ValueError, "already registered"): measurement.register_recorder("test")(DummyRecorder) @parameterized.parameters( - dict(recorder_type=None), - dict(recorder_type="test"), + # No-op if no recorder_type provided. + dict( + recorder_type=None, + expected=None, + ), + dict( + recorder_type="test", + expected="Mock", + ), + # Try initializing from another module. dict( recorder_type=( - "axlearn.experiments.testdata.axlearn_common_measurement_test.dummy_recorder:" + f"axlearn.experiments.testdata.{__name__.replace('.', '_')}.dummy_recorder:" "dummy_recorder" - ) + ), + expected="DummyRecorder", ), ) - def test_initialize(self, recorder_type): - mock_recorder_cls = mock.MagicMock() - mock_recorder_instance = mock_recorder_cls.from_flags.return_value - mock_recorder_instance.record_event.return_value = contextlib.nullcontext() - measurement.register_recorder("test")(mock_recorder_cls) - measurement.register_recorder("dummy_recorder")(RealDummyRecorder) + def test_initialize(self, recorder_type, expected): + mock_recorder = mock.MagicMock() + measurement.register_recorder("test")(mock_recorder) fv = flags.FlagValues() measurement.define_flags(flag_values=fv) @@ -68,17 +69,24 @@ def test_initialize(self, recorder_type): measurement.initialize(fv) if recorder_type is None: + # global_recorder should not be initialized, and record_event should be no-op. self.assertIsNone(measurement.global_recorder) + measurement.record_event(measurement.Event.START_JOB) return recorder_name = recorder_type.split(":", 1)[-1] if recorder_name == "test": - self.assertEqual(mock_recorder_instance, measurement.global_recorder) - mock_recorder_cls.from_flags.assert_called_once() - elif recorder_name == "dummy_recorder": - self.assertIsNotNone(measurement.global_recorder) - self.assertIsInstance(measurement.global_recorder, RealDummyRecorder) + self.assertTrue(mock_recorder.from_flags.called) + + self.assertIn(expected, str(measurement._recorders.get(recorder_name, None))) + self.assertIn(expected, str(measurement.global_recorder)) + + # Ensure that record_event does not fail. + with mock.patch.object(measurement.global_recorder, "record") as mock_record: + measurement.record_event(measurement.Event.START_JOB) + self.assertIn(measurement.Event.START_JOB, mock_record.call_args[0]) + # Ensure that start_monitoring does not fail. with mock.patch.object( measurement.global_recorder, "start_monitoring" ) as mock_start_monitoring: diff --git a/axlearn/common/trainer.py b/axlearn/common/trainer.py index 4332d3a8e..692c2e24a 100644 --- a/axlearn/common/trainer.py +++ b/axlearn/common/trainer.py @@ -241,121 +241,116 @@ def __init__( self._device_monitor = maybe_instantiate(cfg.device_monitor) self._recorder = maybe_instantiate(cfg.recorder) self._is_initialized: bool = False - # Accelerator initialization. - with self._record_event(measurement.Event.ACCELERATOR_INIT): - if cfg.model.dtype is None: - raise ValueError(f"dtype must be explicitly specified for {self.path()}.model") - if cfg.model.param_init is None: - cfg.model.param_init = DefaultInitializer.default_config() - logging.info( - "model.param_init is not specified. 
Default to DefaultInitializer: %s", - cfg.model.param_init, - ) + self._maybe_record_event(measurement.Event.START_ACCELERATOR_INIT) - self._per_param_train_dtype = maybe_instantiate( - canonicalize_per_param_dtype(cfg.train_dtype) + if cfg.model.dtype is None: + raise ValueError(f"dtype must be explicitly specified for {self.path()}.model") + if cfg.model.param_init is None: + cfg.model.param_init = DefaultInitializer.default_config() + logging.info( + "model.param_init is not specified. Default to DefaultInitializer: %s", + cfg.model.param_init, ) - # Create the device mesh. - if devices is None: - self._step_log( - "Devices: global=%s local=%s %s", - jax.device_count(), - jax.local_device_count(), - [device.platform for device in jax.local_devices()], + self._per_param_train_dtype = maybe_instantiate( + canonicalize_per_param_dtype(cfg.train_dtype) + ) + + # Create the device mesh. + if devices is None: + self._step_log( + "Devices: global=%s local=%s %s", + jax.device_count(), + jax.local_device_count(), + [device.platform for device in jax.local_devices()], + ) + else: + local_devices = [d for d in devices.flatten() if d.process_index == jax.process_index()] + self._step_log( + "Devices: global=%s local=%s %s", + len(devices), + len(local_devices), + [device.platform for device in local_devices], + ) + self._step_log("Mesh shape: %s", cfg.mesh_shape) + devices = ( + utils.create_device_mesh(mesh_shape=cfg.mesh_shape) if devices is None else devices + ) + mesh = jax.sharding.Mesh(devices, cfg.mesh_axis_names) + self._step_log("Global mesh: %s", mesh) + self._mesh = mesh + self._context_manager: Callable[[], ContextManager] = ( + maybe_instantiate(cfg.context_manager) or contextlib.nullcontext + ) + xsc_check_policy = None + if cfg.xsc_check_policy: + if jax.default_backend() != "tpu": + # XSC is currently only supported on TPU XLA backend. + logging.warning( + "xsc_check_policy was set for non-TPU XLA backend. Running without XSC." ) else: - local_devices = [ - d for d in devices.flatten() if d.process_index == jax.process_index() - ] - self._step_log( - "Devices: global=%s local=%s %s", - len(devices), - len(local_devices), - [device.platform for device in local_devices], - ) - self._step_log("Mesh shape: %s", cfg.mesh_shape) - devices = ( - utils.create_device_mesh(mesh_shape=cfg.mesh_shape) if devices is None else devices + xsc_check_policy = maybe_instantiate(cfg.xsc_check_policy) + self._xsc_check_policy: Optional[Callable[[int], bool]] = xsc_check_policy + self._compiled_train_step: Optional[jax.stages.Compiled] = None + + # Create all children within the mesh context so that utils.input_partition_spec() works + # properly. + with self.mesh(): + self.input: Input = self._add_child( + "input", + maybe_set_config( + cfg.input, partition_spec=PartitionSpec(cfg.batch_axis_names), is_training=True + ), ) - mesh = jax.sharding.Mesh(devices, cfg.mesh_axis_names) - self._step_log("Global mesh: %s", mesh) - self._mesh = mesh - self._context_manager: Callable[[], ContextManager] = ( - maybe_instantiate(cfg.context_manager) or contextlib.nullcontext + # Start from the beginning of the input dataset by default. + self._input_iter = iter(self.input.dataset()) + cfg.summary_writer.dir = cfg.summary_writer.dir or os.path.join( + cfg.dir, "summaries", "train_train" ) - xsc_check_policy = None - if cfg.xsc_check_policy: - if jax.default_backend() != "tpu": - # XSC is currently only supported on TPU XLA backend. - logging.warning( - "xsc_check_policy was set for non-TPU XLA backend. 
Running without XSC." - ) - else: - xsc_check_policy = maybe_instantiate(cfg.xsc_check_policy) - self._xsc_check_policy: Optional[Callable[[int], bool]] = xsc_check_policy - self._compiled_train_step: Optional[jax.stages.Compiled] = None - - # Create all children within the mesh context so that utils.input_partition_spec() works - # properly. - with self.mesh(): - if cfg.batch_axis_names is not None: - cfg.input = maybe_set_config( - cfg.input, partition_spec=PartitionSpec(cfg.batch_axis_names) - ) - self.input: Input = self._add_child( - "input", maybe_set_config(cfg.input, is_training=True) - ) - # Start from the beginning of the input dataset by default. - self._input_iter = iter(self.input.dataset()) - cfg.summary_writer.dir = cfg.summary_writer.dir or os.path.join( - cfg.dir, "summaries", "train_train" - ) - self._add_child("summary_writer", cfg.summary_writer) - self._add_child("model", cfg.model) - self._add_child("learner", cfg.learner) - cfg.checkpointer.dir = cfg.checkpointer.dir or os.path.join(cfg.dir, "checkpoints") - self._add_child("checkpointer", cfg.checkpointer) - if cfg.init_state_builder is not None: - self._add_child("init_state_builder", cfg.init_state_builder) - - self._model_param_specs = self.model.create_parameter_specs_recursively() - model_param_partition_specs = jax.tree.map( - lambda spec: spec.mesh_axes, self._model_param_specs - ) - for name, spec in utils.flatten_items(self._model_param_specs): - self._step_log("Model param spec: %s=%s", name, spec) - self._learner_state_partition_specs = self.learner.create_state_partition_specs( - self._model_param_specs + self._add_child("summary_writer", cfg.summary_writer) + self._add_child("model", cfg.model) + self._add_child("learner", cfg.learner) + cfg.checkpointer.dir = cfg.checkpointer.dir or os.path.join(cfg.dir, "checkpoints") + self._add_child("checkpointer", cfg.checkpointer) + if cfg.init_state_builder is not None: + self._add_child("init_state_builder", cfg.init_state_builder) + + self._model_param_specs = self.model.create_parameter_specs_recursively() + model_param_partition_specs = jax.tree.map( + lambda spec: spec.mesh_axes, self._model_param_specs + ) + for name, spec in utils.flatten_items(self._model_param_specs): + self._step_log("Model param spec: %s=%s", name, spec) + self._learner_state_partition_specs = self.learner.create_state_partition_specs( + self._model_param_specs + ) + for name, spec in utils.flatten_items(self._learner_state_partition_specs): + self._step_log("Learner state spec: %s=%s", name, spec) + self._trainer_state_specs = TrainerState( + prng_key=ParameterSpec(dtype=jnp.uint32, shape=[4], mesh_axes=PartitionSpec(None)), + model=self._model_param_specs, + learner=self._learner_state_partition_specs, + ) + self._trainer_state_partition_specs: TrainerState = jax.tree.map( + lambda spec: spec.sharding, self._trainer_state_specs + ) + # Create evalers, which depend on model_param_partition_specs. 
+ self._evalers = {} + for evaler_name, evaler_cfg in cfg.evalers.items(): + evaler_cfg.summary_writer.dir = evaler_cfg.summary_writer.dir or os.path.join( + cfg.dir, "summaries", evaler_name ) - for name, spec in utils.flatten_items(self._learner_state_partition_specs): - self._step_log("Learner state spec: %s=%s", name, spec) - self._trainer_state_specs = TrainerState( - prng_key=ParameterSpec( - dtype=jnp.uint32, shape=[4], mesh_axes=PartitionSpec(None) - ), - model=self._model_param_specs, - learner=self._learner_state_partition_specs, + maybe_set_config( + evaler_cfg.input, partition_spec=PartitionSpec(cfg.batch_axis_names) ) - self._trainer_state_partition_specs: TrainerState = jax.tree.map( - lambda spec: spec.sharding, self._trainer_state_specs + self._evalers[evaler_name] = self._add_child( + evaler_name, + evaler_cfg, + model=self.model, + model_param_partition_specs=model_param_partition_specs, ) - # Create evalers, which depend on model_param_partition_specs. - self._evalers = {} - for evaler_name, evaler_cfg in cfg.evalers.items(): - evaler_cfg.summary_writer.dir = evaler_cfg.summary_writer.dir or os.path.join( - cfg.dir, "summaries", evaler_name - ) - if cfg.batch_axis_names is not None: - maybe_set_config( - evaler_cfg.input, partition_spec=PartitionSpec(cfg.batch_axis_names) - ) - self._evalers[evaler_name] = self._add_child( - evaler_name, - evaler_cfg, - model=self.model, - model_param_partition_specs=model_param_partition_specs, - ) + self._maybe_record_event(measurement.Event.END_ACCELERATOR_INIT) @property def step(self): @@ -373,15 +368,6 @@ def trainer_state_specs(self): def trainer_state_partition_specs(self): return self._trainer_state_partition_specs - @contextlib.contextmanager - def _record_event(self, event: measurement.Event, *args, **kwargs): - """A helper to record an event if a recorder is configured.""" - if self._recorder: - with self._recorder.record_event(event, *args, **kwargs) as event_manager: - yield event_manager - else: - yield - def _train_step_input_partition_specs(self): # Note that subclasses may override this method to set a partition spec for pjit which is # different from that of the input partition spec. @@ -539,6 +525,10 @@ def _should_force_run_evals( ) return force_run_evals + def _maybe_record_event(self, event: measurement.Event, *args, **kwargs): + if self._recorder is not None: + self._recorder.record(event, *args, **kwargs) + # pylint: disable-next=too-many-statements,too-many-branches def run( self, prng_key: Tensor, *, return_evaler_summaries: Optional[Union[bool, set[str]]] = None @@ -564,7 +554,6 @@ def run( different types of values such as WeightedScalar, Tensor, or string, depending on the specific `metric_calculator` config of the evaler. """ - with ( ( self._device_monitor.start_monitoring() @@ -575,7 +564,6 @@ def run( self.mesh(), jax.log_compiles(self.vlog_is_on(1)), self._context_manager(), - self._recorder.maybe_monitor_all_goodput(), ): cfg = self.config # Check if need to force run evals at the last training step. @@ -584,9 +572,8 @@ def run( ) # Prepare training. 
- with self._record_event(measurement.Event.TRAINING_PREPARATION): - if not self._prepare_training(prng_key): - return None + if not self._prepare_training(prng_key): + return None self._is_initialized = True @@ -599,10 +586,10 @@ def run( input_iterator = self.input.batches(self._input_iter) while True: + self._maybe_record_event(measurement.Event.START_DATA_LOADING) try: - with self._record_event(measurement.Event.DATA_LOADING): - input_batch = next(input_iterator) - + input_batch = next(input_iterator) + self._maybe_record_event(measurement.Event.END_DATA_LOADING) logging.log_first_n( logging.INFO, "input_batch=%s", 3, utils.shapes(input_batch) ) @@ -612,18 +599,18 @@ def run( self._step = self._step + 1 self.vlog(3, "Start step %s", self.step) - with self._record_event(measurement.Event.STEP, self._step): - output = self._run_step( - utils.host_to_global_array( - input_batch, - partition=self._train_step_input_partition_specs(), - ), - force_run_evals=( - force_run_eval_sets_at_max_step - if self.step >= cfg.max_step - else None - ), - ) + self._maybe_record_event(measurement.Event.START_STEP, self._step) + output = self._run_step( + utils.host_to_global_device_array( + input_batch, + partition=self._train_step_input_partition_specs(), + ), + force_run_evals=( + force_run_eval_sets_at_max_step + if self.step >= cfg.max_step + else None + ), + ) self.vlog(3, "Done step %s", self.step) num_steps += 1 if num_steps % 1 == 0: @@ -637,6 +624,9 @@ def run( self._step_log("Reached max_step=%s. Stopping", cfg.max_step) break except StopIteration: + # Add END_DATA_LOADING event here to close the unpaired START_DATA_LOADING + # event. + self._maybe_record_event(measurement.Event.END_DATA_LOADING) break if self.step < cfg.max_step: self._step_log("Reached end of inputs. Stopping") @@ -877,6 +867,7 @@ def _prepare_training(self, prng_key: Tensor) -> bool: A boolean indicating whether the model training should start. If not, return None from the `run` function. """ + self._maybe_record_event(measurement.Event.START_TRAINING_PREPARATION) cfg = self.config # Attempt to restore the latest checkpoint, which may contain a saved `_input_iter`. 
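The refactor above replaces the `_record_event` context manager with explicit `START_*`/`END_*` calls, which is why the `StopIteration` handler has to emit a closing `END_DATA_LOADING` by hand. Below is a minimal sketch of that pairing concern, assuming nothing beyond a recorder object with a `record(event)` method; `Recorder` and `paired_event` are illustrative stand-ins, not axlearn APIs.

```python
import contextlib
from enum import Enum


class Event(Enum):
    START_DATA_LOADING = "start_data_loading"
    END_DATA_LOADING = "end_data_loading"


class Recorder:
    """Stand-in for the configured measurement recorder."""

    def record(self, event: Event, *args, **kwargs) -> None:
        print(f"recorded {event.name}")


@contextlib.contextmanager
def paired_event(recorder: Recorder, start: Event, end: Event):
    """Records `start` on entry and guarantees the matching `end` on exit."""
    recorder.record(start)
    try:
        yield
    finally:
        # Fires on success, StopIteration, or any other exception, so no
        # unpaired START event is left behind.
        recorder.record(end)


if __name__ == "__main__":
    recorder = Recorder()
    iterator = iter(())  # An exhausted input iterator.
    try:
        with paired_event(recorder, Event.START_DATA_LOADING, Event.END_DATA_LOADING):
            batch = next(iterator)  # Raises StopIteration immediately.
    except StopIteration:
        pass  # Both START and END were recorded.
```

A `try`/`finally` pairing like this closes the event on every exit path; with bare `START`/`END` calls, each early `break`, `return`, or exception path needs its own closing call, which is exactly what the patch adds for `StopIteration`.
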
@@ -909,6 +900,7 @@ def _prepare_training(self, prng_key: Tensor) -> bool: return False self._jit_train_step = self._pjit_train_step() + self._maybe_record_event(measurement.Event.END_TRAINING_PREPARATION) return True def restore_checkpoint(self, restore_step: Optional[int] = None) -> Optional[int]: @@ -1047,29 +1039,36 @@ def _get_compiled_train_step_fn( mesh_shape=cfg.mesh_shape, mesh_axis_names=cfg.mesh_axis_names, device_kind=device_kind ) if not with_xsc: - with self._record_event( - measurement.Event.CUSTOM_BADPUT_EVENT, + self._maybe_record_event( + measurement.Event.START_CUSTOM_BADPUT_EVENT, custom_badput_event_type="COMPILATION_NO_XSC", - ): - self._compiled_train_step = self.compile_train_step( - trainer_state=trainer_state, input_batch=input_batch, compiler_options=options - ) + ) + self._compiled_train_step = self.compile_train_step( + trainer_state=trainer_state, input_batch=input_batch, compiler_options=options + ) + self._maybe_record_event( + measurement.Event.END_CUSTOM_BADPUT_EVENT, + custom_badput_event_type="COMPILATION_NO_XSC", + ) return self._compiled_train_step - logging.log_first_n(logging.INFO, "Compiling XSC train step.", 1) - with self._record_event( - measurement.Event.CUSTOM_BADPUT_EVENT, + self._maybe_record_event( + measurement.Event.START_CUSTOM_BADPUT_EVENT, custom_badput_event_type="COMPILATION_WITH_XSC", - ): - compiled_jit_train_step_fn = self.compile_train_step( - trainer_state=trainer_state, - input_batch=input_batch, - compiler_options=options - | infer_xsc_compiler_options( - halt_on_detection=True, repeat_count=1, device_kind=device_kind - ), - ) + ) + compiled_jit_train_step_fn = self.compile_train_step( + trainer_state=trainer_state, + input_batch=input_batch, + compiler_options=options + | infer_xsc_compiler_options( + halt_on_detection=True, repeat_count=1, device_kind=device_kind + ), + ) + self._maybe_record_event( + measurement.Event.END_CUSTOM_BADPUT_EVENT, + custom_badput_event_type="COMPILATION_WITH_XSC", + ) return compiled_jit_train_step_fn def _run_step( @@ -1126,23 +1125,26 @@ def _run_eval( force_runs: Optional[set[str]] = None, ) -> dict[str, Any]: """Runs evaluations and returns the corresponding summaries.""" - with self._record_event( - measurement.Event.CUSTOM_BADPUT_EVENT, custom_badput_event_type="EVAL" - ): - evaler_summaries = {} - # Note: we will use the same eval key as the training keys of the future step, - # which should be okay. - prng_key = self._trainer_state.prng_key - for evaler_name, evaler in self._evalers.items(): - prng_key, summaries, _ = evaler.eval_step( - self.step, - prng_key=prng_key, - model_params=self.model_params_for_eval(), - train_summaries=train_summaries, - force_run=bool(force_runs is not None and evaler_name in force_runs), - ) - evaler_summaries[evaler_name] = summaries - return evaler_summaries + self._maybe_record_event( + measurement.Event.START_CUSTOM_BADPUT_EVENT, custom_badput_event_type="EVAL" + ) + evaler_summaries = {} + # Note: we will use the same eval key as the training keys of the future step, + # which should be okay. 
+ prng_key = self._trainer_state.prng_key + for evaler_name, evaler in self._evalers.items(): + prng_key, summaries, _ = evaler.eval_step( + self.step, + prng_key=prng_key, + model_params=self.model_params_for_eval(), + train_summaries=train_summaries, + force_run=bool(force_runs is not None and evaler_name in force_runs), + ) + evaler_summaries[evaler_name] = summaries + self._maybe_record_event( + measurement.Event.END_CUSTOM_BADPUT_EVENT, custom_badput_event_type="EVAL" + ) + return evaler_summaries def _pjit_train_step(self) -> jax.stages.Wrapped: return pjit( diff --git a/docs/05-Goodput-Monitoring.md b/docs/05-Goodput-Monitoring.md index cb17f6989..ca1452c19 100644 --- a/docs/05-Goodput-Monitoring.md +++ b/docs/05-Goodput-Monitoring.md @@ -1,14 +1,10 @@ # ML Goodput Monitoring -AXLearn supports automatic measurement and upload of a wide range of workload -metrics using the **ML Goodput Measurement** library. This includes: -* **Goodput** and **Badput Breakdown** -* **Step Metrics** (Ideal Step Time, Step Time Deviation, Last Productive Step etc.) -* **Workload Hang Metrics** (Disruption Count, Step Info) -* **Rolling Window Goodput & Badput Breakdown** +AXLearn supports automatic measurement and upload of workload metrics such as +Goodput, Badput Breakdown and Step Time Deviation using the ML Goodput +Measurement library. The [ML Goodput Measurement](https://github.com/AI-Hypercomputer/ml-goodput-measurement) library currently supports monitoring workloads running on Google Cloud Platform. For more information on details of the library, visit the Github page or the [ml-goodput-measurement](https://pypi.org/project/ml-goodput-measurement/) PyPI package documentation. - ### What is Goodput Goodput is the metric that measures the efficiency of model training jobs, i.e. productive time spent on training progress proportional to the total time spent @@ -19,26 +15,12 @@ improve to get the most value from their accelerators. Badput is the metric that measures time that a workload spent on anything that is not productive training proportional to the total time spent by the workload. For example, the time spent in accelerator initialization, training preparation, -program startup, data loading, portions of checkpointing, recovering from -disruptions, wasted progress since the last checkpoint etc. all contribute to Badput. - -The ML Goodput Measurement library exposes Badput Breakdown. Further details of -each bucket can be found [here](https://github.com/AI-Hypercomputer/ml-goodput-measurement?tab=readme-ov-file#badput-breakdown-details) - -## What is Rolling Window Goodput & Badput -The ML Goodput Measurement library allows users to monitor goodput and badput -breakdown metrics within specific, moving time windows. You can specify a list -of rolling window interval sizes in seconds, and the library will asynchronously -query and upload metrics calculated only within the context of those windows. -This is useful for understanding workload performance over recent, specific -durations (e.g., the last 24 hours). +program startup, data loading, portions of checkpointing, disruptions and +wasted progress since the last checkpoint etc. all contribute to Badput. -If the workload's actual runtime timeline is shorter than a requested window size, -the entire runtime timeline of the workload is used for the metrics computation. +The ML Goodput Measurement library exposes Badput Breakdown. 
Further details of each bucket can be found [here](https://github.com/AI-Hypercomputer/ml-goodput-measurement?tab=readme-ov-file#badput-breakdown-details)
 
-> **Note**: Both the standard (cumulative) and rolling window query APIs can be enabled simultaneously to get a complete picture of your workload's performance.
-
-### What are Ideal Step Time and Step Time Deviation
+### What is Step Time Deviation
 
 Step Time Deviation is the metric that measures deviation of step time (in
 seconds) from ideal step time. It is the difference between the actual time
@@ -51,8 +33,8 @@ The formula for step deviation is:
 
 Ideal step time is equal to the user-configured `ideal_step_time` if it is
 provided. If the user has not specified an ideal step time, then the ideal step
-time is calculated as a weighted average of the "normal" step times recorded for
-the workload, where a "normal" step is defined as having a duration less than or
+time is calculated as the average of the "normal" step times recorded for the
+workload, where a "normal" step is defined as having a duration less than or
 equal to `median + median absolute deviation * 3` of the sample space of step
 times. This computation requires at least 10 recorded steps.
@@ -95,7 +77,7 @@ project, then do the following:
 
 Please use a unique workload name, unless you intend to monitor cumulative
 Goodput/Badput metrics of a previous workload along with your current workload.
 
-### How to Monitor Cumulative Goodput Metrics
+### How to Monitor Goodput and Badput
 
 To enable Goodput recording and monitoring on AXLearn, follow the example below.
 
@@ -112,22 +94,24 @@ To enable Goodput recording and monitoring on AXLearn, follow the example below.
     --recorder_spec=upload_interval=30 \
 ```
 
-### How to Monitor Rolling Window Goodput Metrics
+### How to Monitor Step Time Deviation
 
-To enable rolling window metrics, set `enable_rolling_window_goodput_monitoring` to `True`
-and provide a list of interval sizes for `rolling_window_size` in seconds:
+AXLearn enables step time deviation monitoring by default. You can configure
+the upload frequency by setting
+`--recorder_spec=step_deviation_interval_seconds=30`. To disable step deviation
+monitoring, set `--recorder_spec=step_deviation_interval_seconds=-1`.
 
 ```bash
-axlearn gcp launch run --instance_type=tpu-v5litepod-16 \
+ axlearn gcp launch run --instance_type=tpu-v5litepod-16 \
     --bundler_type=artifactregistry --bundler_spec=image=tpu \
     --bundler_spec=dockerfile=Dockerfile \
-    -- python3 -m my_training_job \
+    --name= \
+    -- python3 -m ...training-config... \
     --recorder_type=axlearn.cloud.gcp.measurement:goodput \
     --recorder_spec=name= \
    --recorder_spec=upload_dir=my-output-directory/summaries \
     --recorder_spec=upload_interval=30 \
-    --recorder_spec=enable_rolling_window_goodput_monitoring=True \
-    --recorder_spec=rolling_window_size=86400,259200,432000
+    --recorder_spec=step_deviation_interval_seconds=30
 ```
 
 ### Visualize on Tensorboard
@@ -137,16 +121,12 @@ axlearn gcp launch run --instance_type=tpu-v5litepod-16 \
 
 ### Enabling Google Cloud Monitoring
 
-By default, when Goodput monitoring is enabled via the recorder, AXLearn automatically pushes metrics to Google Cloud Monitoring.
-
-- **Cumulative Metrics** are enabled by default when you specify the `recorder_type`.
-  To disable this, you would need to set `enable_gcp_goodput_metrics` to `False` in
-  `goodput_monitoring.GCPOptions` within the `cloud/gcp/measurement.py` file.
-- **Rolling Window Metrics** can be explicitly enabled by setting
-  `enable_rolling_window_goodput_monitoring` to `True` and providing window sizes
-  via `rolling_window_size`.
-
-You can enable either cumulative monitoring, rolling window monitoring, or both simultaneously.
+AXLearn can additionally push Goodput, Badput and step time deviation metrics
+to Google Cloud Monitoring. By default, if Goodput monitoring is enabled, the
+data is published to Google Cloud Monitoring. Set the variables
+`enable_gcp_goodput_metrics` and `enable_gcp_step_deviation_metrics` to `False` in
+`goodput_monitoring.GCPOptions` in `cloud/gcp/measurement.py` to disable Goodput and
+step deviation uploads to GCM, respectively.
 
 ```bash
 axlearn gcp launch run --instance_type=tpu-v5litepod-16 \
@@ -158,8 +138,7 @@ You can enable either cumulative monitoring, rolling window monitoring, or both
     --recorder_spec=name= \
     --recorder_spec=upload_dir=my-output-directory/summaries \
     --recorder_spec=upload_interval=30 \
-    --recorder_spec=enable_rolling_window_goodput_monitoring=True \
-    --recorder_spec=rolling_window_size=86400,604800
+    --recorder_spec=step_deviation_interval_seconds=30
 ```
 
 #### Visualization in Google Cloud Monitoring
 
@@ -180,38 +159,3 @@ To visualize the collected metrics within Google Cloud Monitoring:
 
    c. [**Performance:**](https://cloud.google.com/monitoring/api/metrics_gcp#:~:text=workload/performance)
    Represents the workload's performance metric, specifically step deviation
    in this context, measured by `compute.googleapis.com/workload/performance`.
-
-#### Google Cloud Monitoring Dashboard: Goodput Monitor
-
-Following are instructions for deploying a custom dashboard `goodput_dashboard.json`
-to your Google Cloud project's Monitoring console. This dashboard
-offers a comprehensive view of "Goodput" metrics, helping you monitor the
-your workloads and set up custom alerts for "events" such as performance degradation.
-
-
-#### Deployment Steps
-
-Follow these steps to create a new custom dashboard using the provided JSON
-configuration:
-
-1. **Navigate to the Monitoring Console**: In your Google Cloud project,
-   go to the **Monitoring** section. From the left-hand navigation menu,
-   select **Dashboards**.
-
-2. **Create Custom Dashboard**: Click the **Create Custom Dashboard** button.
-
-3. **Use JSON Editor**: In the new dashboard interface, select the
-   **JSON editor** option.
-
-4. **Copy and Save Configuration**: Open the [goodput_dashboard.json](https://github.com/AI-Hypercomputer/ml-goodput-measurement/blob/main/ml_goodput_measurement/dashboards/goodput_dashboard.json) file.
-   Copy its entire content and paste it into the JSON editor. Once pasted,
-   click **Save**.
-
-
-Your "Goodput Monitor" dashboard should now be visible and operational within
-your custom dashboards list.
-
-> **_NOTE:_** This dashboard is intended to be a starting point for your
-> monitoring needs. We recommend customizing it to meet your specific needs.
-> Please refer to the [Monitoring Dashboard documentation](https://cloud.google.com/monitoring/dashboards)
-> for further guidance and customization options.
diff --git a/pyproject.toml b/pyproject.toml
index dfdc377d9..5e55b5cda 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -99,7 +99,7 @@ gcp = [
     "google-cloud-compute==1.19.2", # Needed for region discovery for CloudBuild API access.
"google-cloud-core==2.3.3", "google-cloud-build==3.24.1", - "ml-goodput-measurement==0.0.13", + "ml-goodput-measurement==0.0.10", "pika==1.3.2", # used by event queue "pyOpenSSL>=22.1.0", # compat with cryptography version. "tpu-info==0.2.0", # For TPU monitoring from libtpu. https://github.com/AI-Hypercomputer/cloud-accelerator-diagnostics/tree/main/tpu_info From f888f903dabb11368033585ff019c20945dd31da Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Wed, 6 Aug 2025 13:32:07 -0700 Subject: [PATCH 44/57] print jax_devices --- axlearn/common/checkpointer_orbax_emergency.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/axlearn/common/checkpointer_orbax_emergency.py b/axlearn/common/checkpointer_orbax_emergency.py index f868488e0..9a0ea9363 100644 --- a/axlearn/common/checkpointer_orbax_emergency.py +++ b/axlearn/common/checkpointer_orbax_emergency.py @@ -314,6 +314,8 @@ def _init_consistent_proc_ids( # Then, rank 0 assigns inv_proc_id for worker that's missing their inv_proc_id and find the # coordinator address. if local_proc_info.cur_proc_id == 0: + jax_devices = [dev.id for dev in jax.devices()] + logging.info("jax_devices=%s", jax_devices) ids = client.key_value_dir_get(key_prefix) proc_infos: list[_ProcessInfo] = [] From 8ef02208993fe84fa750def495b3c166673131e8 Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Wed, 6 Aug 2025 16:40:34 -0700 Subject: [PATCH 45/57] fsdp=256 data=-1 so ici_dp=1 --- axlearn/experiments/text/gpt/fuji.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/axlearn/experiments/text/gpt/fuji.py b/axlearn/experiments/text/gpt/fuji.py index 40fb010b7..7bfab409b 100644 --- a/axlearn/experiments/text/gpt/fuji.py +++ b/axlearn/experiments/text/gpt/fuji.py @@ -718,7 +718,7 @@ def get_trainer_kwargs( ChainConfigModifier.default_config().set( config_modifiers=[ MeshShapeModifier.default_config().set( - mesh_shape=mesh_shape_from_axes(data=-1, fsdp=128) + mesh_shape=mesh_shape_from_axes(data=-1, fsdp=256) ), RematSpecModifier.default_config().set( remat_policies={ From 2c2e134817a3e87465d22e5ebe4c130da038232a Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Thu, 7 Aug 2025 10:35:32 -0700 Subject: [PATCH 46/57] use latest main of orbax --- axlearn/common/checkpointer_orbax_emergency.py | 3 +++ pyproject.toml | 3 ++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/axlearn/common/checkpointer_orbax_emergency.py b/axlearn/common/checkpointer_orbax_emergency.py index 9a0ea9363..a1786b974 100644 --- a/axlearn/common/checkpointer_orbax_emergency.py +++ b/axlearn/common/checkpointer_orbax_emergency.py @@ -90,6 +90,9 @@ def setup(spec: str): FLAGS.process_id = info.inv_proc_id FLAGS.distributed_coordinator = info.address FLAGS.experimental_orbax_use_distributed_process_id = True + # Required for case when slices swap and ici_dp=2 or higher. + # PR that introduced this flag: https://github.com/google/orbax/pull/2222 + FLAGS.experimental_use_distributed_id_for_mesh_consistency = False yield diff --git a/pyproject.toml b/pyproject.toml index 5e55b5cda..d861d0a82 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -156,7 +156,8 @@ mmau = [ orbax = [ "humanize==4.10.0", # "orbax-checkpoint==0.11.20", - "orbax-checkpoint @ git+https://github.com/samos123/orbax.git@v0.12.0#subdirectory=checkpoint" + # "orbax-checkpoint @ git+https://github.com/samos123/orbax.git@v0.12.0#subdirectory=checkpoint" + "orbax-checkpoint @ git+https://github.com/google/orbax.git@main#subdirectory=checkpoint" ] # Audio dependencies. 
audio = [ From 8b3a42512d54195f74fed74041752a5026159265 Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Thu, 7 Aug 2025 10:40:54 -0700 Subject: [PATCH 47/57] 70b fsdp=64,data=-1 --- axlearn/experiments/text/gpt/fuji.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/axlearn/experiments/text/gpt/fuji.py b/axlearn/experiments/text/gpt/fuji.py index 7bfab409b..d7242e33e 100644 --- a/axlearn/experiments/text/gpt/fuji.py +++ b/axlearn/experiments/text/gpt/fuji.py @@ -718,7 +718,7 @@ def get_trainer_kwargs( ChainConfigModifier.default_config().set( config_modifiers=[ MeshShapeModifier.default_config().set( - mesh_shape=mesh_shape_from_axes(data=-1, fsdp=256) + mesh_shape=mesh_shape_from_axes(data=-1, fsdp=64) ), RematSpecModifier.default_config().set( remat_policies={ From ea6fc62926c1c0459aca7d50c572874abfc7c17d Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Thu, 7 Aug 2025 11:40:04 -0700 Subject: [PATCH 48/57] fsdp=32 with 70b --- axlearn/experiments/text/gpt/fuji.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/axlearn/experiments/text/gpt/fuji.py b/axlearn/experiments/text/gpt/fuji.py index d7242e33e..8f7c250c6 100644 --- a/axlearn/experiments/text/gpt/fuji.py +++ b/axlearn/experiments/text/gpt/fuji.py @@ -718,7 +718,7 @@ def get_trainer_kwargs( ChainConfigModifier.default_config().set( config_modifiers=[ MeshShapeModifier.default_config().set( - mesh_shape=mesh_shape_from_axes(data=-1, fsdp=64) + mesh_shape=mesh_shape_from_axes(data=-1, fsdp=32) ), RematSpecModifier.default_config().set( remat_policies={ From 5f7037773c16eb88f4827b13a503599e7fbc9cea Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Thu, 7 Aug 2025 11:49:45 -0700 Subject: [PATCH 49/57] 70b fsdp=64 --- axlearn/experiments/text/gpt/fuji.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/axlearn/experiments/text/gpt/fuji.py b/axlearn/experiments/text/gpt/fuji.py index 8f7c250c6..d7242e33e 100644 --- a/axlearn/experiments/text/gpt/fuji.py +++ b/axlearn/experiments/text/gpt/fuji.py @@ -718,7 +718,7 @@ def get_trainer_kwargs( ChainConfigModifier.default_config().set( config_modifiers=[ MeshShapeModifier.default_config().set( - mesh_shape=mesh_shape_from_axes(data=-1, fsdp=32) + mesh_shape=mesh_shape_from_axes(data=-1, fsdp=64) ), RematSpecModifier.default_config().set( remat_policies={ From e33e04aa9fe2757d1bd5999a022ba468dd166563 Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Thu, 7 Aug 2025 13:06:08 -0700 Subject: [PATCH 50/57] bump orbax to 0.11.21 --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d861d0a82..418d338e7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -155,9 +155,9 @@ mmau = [ # Orbax checkpointing. orbax = [ "humanize==4.10.0", - # "orbax-checkpoint==0.11.20", + "orbax-checkpoint==0.11.21", # "orbax-checkpoint @ git+https://github.com/samos123/orbax.git@v0.12.0#subdirectory=checkpoint" - "orbax-checkpoint @ git+https://github.com/google/orbax.git@main#subdirectory=checkpoint" + # "orbax-checkpoint @ git+https://github.com/google/orbax.git@main#subdirectory=checkpoint" ] # Audio dependencies. 
audio = [ From dd144626beea831e3a1345f90a896e5e0a350c8f Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Thu, 7 Aug 2025 15:44:36 -0700 Subject: [PATCH 51/57] 7b fsdp=16 --- axlearn/experiments/text/gpt/fuji.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/axlearn/experiments/text/gpt/fuji.py b/axlearn/experiments/text/gpt/fuji.py index d7242e33e..b73e3d6b9 100644 --- a/axlearn/experiments/text/gpt/fuji.py +++ b/axlearn/experiments/text/gpt/fuji.py @@ -384,7 +384,7 @@ def get_trainer_kwargs( max_sequence_length=max_sequence_length, train_batch_size=gbs, max_step=max_step, - mesh_shape=mesh_shape_from_axes(data=-1, fsdp=8), + mesh_shape=mesh_shape_from_axes(data=-1, fsdp=16), mesh_rules=( # Step time: # v1 on tpu-v4-1024 (512 chips): 3.03s @@ -456,7 +456,7 @@ def get_trainer_kwargs( ChainConfigModifier.default_config().set( config_modifiers=[ MeshShapeModifier.default_config().set( - mesh_shape=mesh_shape_from_axes(data=-1, fsdp=256) + mesh_shape=mesh_shape_from_axes(data=-1, fsdp=16) ), RematSpecModifier.default_config().set( remat_policies={ From 5b1401a0cfa0e44a51c69b19e127a7b9d33c8734 Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Thu, 7 Aug 2025 19:51:18 -0700 Subject: [PATCH 52/57] use jun's fix for fsdp=16 data=16 --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 418d338e7..f9c150e8e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -155,8 +155,8 @@ mmau = [ # Orbax checkpointing. orbax = [ "humanize==4.10.0", - "orbax-checkpoint==0.11.21", - # "orbax-checkpoint @ git+https://github.com/samos123/orbax.git@v0.12.0#subdirectory=checkpoint" + # "orbax-checkpoint==0.11.21", + "orbax-checkpoint @ git+https://github.com/samos123/orbax.git@v0.11.21-em#subdirectory=checkpoint" # "orbax-checkpoint @ git+https://github.com/google/orbax.git@main#subdirectory=checkpoint" ] # Audio dependencies. 
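The patches above move the `orbax-checkpoint` requirement between a PyPI release and git forks several times. When debugging restore behavior, it helps to confirm which build a job actually installed. A stdlib-only check is sketched below; it uses only standard `importlib.metadata` APIs, and `direct_url.json` is the PEP 610 record pip writes for VCS installs.

```python
from importlib import metadata

# Which orbax-checkpoint build did this environment actually pick up?
dist = metadata.distribution("orbax-checkpoint")
print("orbax-checkpoint version:", dist.version)

# For git installs, pip records the repository URL and commit in
# direct_url.json (PEP 610); index installs have no such file.
direct_url = dist.read_text("direct_url.json")
print(direct_url or "installed from an index, not a VCS URL")
```

Running the same check inside the bundled container, rather than on the submitting host, rules out a stale image as the source of a surprising pin.
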
From b3f67b9c2754de841480d427a77fe1bb98d5e88b Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Fri, 8 Aug 2025 09:00:20 -0700 Subject: [PATCH 53/57] fsdp=256 7b --- axlearn/experiments/text/gpt/fuji.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/axlearn/experiments/text/gpt/fuji.py b/axlearn/experiments/text/gpt/fuji.py index b73e3d6b9..bed8e5f78 100644 --- a/axlearn/experiments/text/gpt/fuji.py +++ b/axlearn/experiments/text/gpt/fuji.py @@ -456,7 +456,7 @@ def get_trainer_kwargs( ChainConfigModifier.default_config().set( config_modifiers=[ MeshShapeModifier.default_config().set( - mesh_shape=mesh_shape_from_axes(data=-1, fsdp=16) + mesh_shape=mesh_shape_from_axes(data=-1, fsdp=256) ), RematSpecModifier.default_config().set( remat_policies={ From 09df9e943cc373d1920cc3387526fd58ae6520d4 Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Fri, 8 Aug 2025 16:00:02 -0700 Subject: [PATCH 54/57] use orbax em with single replica GCS restore --- pyproject.toml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index f9c150e8e..fc79d1766 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -156,7 +156,9 @@ mmau = [ orbax = [ "humanize==4.10.0", # "orbax-checkpoint==0.11.21", - "orbax-checkpoint @ git+https://github.com/samos123/orbax.git@v0.11.21-em#subdirectory=checkpoint" + # "orbax-checkpoint @ git+https://github.com/samos123/orbax.git@v0.11.21-em#subdirectory=checkpoint" + # Single replica restore from GCS with Orbax EM + "orbax-checkpoint @ git+https://github.com/samos123/orbax.git@v0.11.21-em-sr#subdirectory=checkpoint" # "orbax-checkpoint @ git+https://github.com/google/orbax.git@main#subdirectory=checkpoint" ] # Audio dependencies. From edd48eede6c2eccaff56cf0668e3f0e11500b105 Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Fri, 8 Aug 2025 16:01:07 -0700 Subject: [PATCH 55/57] 70b fsdp=64 --- axlearn/experiments/text/gpt/fuji.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/axlearn/experiments/text/gpt/fuji.py b/axlearn/experiments/text/gpt/fuji.py index bed8e5f78..711613958 100644 --- a/axlearn/experiments/text/gpt/fuji.py +++ b/axlearn/experiments/text/gpt/fuji.py @@ -669,7 +669,7 @@ def get_trainer_kwargs( ChainConfigModifier.default_config().set( config_modifiers=[ MeshShapeModifier.default_config().set( - mesh_shape=mesh_shape_from_axes(data=-1, fsdp=256) + mesh_shape=mesh_shape_from_axes(data=-1, fsdp=64) ), RematSpecModifier.default_config().set( remat_policies={ From ff64d65f2e759e5c745ec06ba7e0c06023704bfe Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Fri, 8 Aug 2025 22:49:22 -0700 Subject: [PATCH 56/57] try with new orbax fix --- force_delete_terminating_pods.sh | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/force_delete_terminating_pods.sh b/force_delete_terminating_pods.sh index cc4665e50..dea647c28 100755 --- a/force_delete_terminating_pods.sh +++ b/force_delete_terminating_pods.sh @@ -5,7 +5,7 @@ # --- Configuration --- # The maximum duration (in seconds) a pod is allowed to be in the Terminating state. -STUCK_DURATION_SECONDS=1200 +STUCK_DURATION_SECONDS=300 # How often (in seconds) the script should check for stuck pods. 
CHECK_INTERVAL_SECONDS=60 diff --git a/pyproject.toml b/pyproject.toml index fc79d1766..b3a5e9d28 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -158,7 +158,7 @@ orbax = [ # "orbax-checkpoint==0.11.21", # "orbax-checkpoint @ git+https://github.com/samos123/orbax.git@v0.11.21-em#subdirectory=checkpoint" # Single replica restore from GCS with Orbax EM - "orbax-checkpoint @ git+https://github.com/samos123/orbax.git@v0.11.21-em-sr#subdirectory=checkpoint" + "orbax-checkpoint @ git+https://github.com/samos123/orbax.git@v0.11.21-em-sr2#subdirectory=checkpoint" # "orbax-checkpoint @ git+https://github.com/google/orbax.git@main#subdirectory=checkpoint" ] # Audio dependencies. From 056e0ea0a99e838ac0b2043d355806de8a78e36d Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Fri, 15 Aug 2025 21:44:13 -0700 Subject: [PATCH 57/57] update orbax --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index b3a5e9d28..1d346cc03 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -158,7 +158,7 @@ orbax = [ # "orbax-checkpoint==0.11.21", # "orbax-checkpoint @ git+https://github.com/samos123/orbax.git@v0.11.21-em#subdirectory=checkpoint" # Single replica restore from GCS with Orbax EM - "orbax-checkpoint @ git+https://github.com/samos123/orbax.git@v0.11.21-em-sr2#subdirectory=checkpoint" + "orbax-checkpoint @ git+https://github.com/samos123/orbax.git@v0.11.21-em-aug13#subdirectory=checkpoint" # "orbax-checkpoint @ git+https://github.com/google/orbax.git@main#subdirectory=checkpoint" ] # Audio dependencies.
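
A closing note on the mesh churn in this series: several patches only toggle `fsdp` between 16, 32, 64 and 256 while keeping `data=-1`. Since a `-1` axis is inferred from the total device count, each `fsdp` change implicitly changes the data-parallel degree as well. A rough sketch of that inference follows; `infer_mesh_shape` is a hypothetical stand-in, not the axlearn helper.

```python
import math


def infer_mesh_shape(mesh_shape: tuple[int, ...], num_devices: int) -> tuple[int, ...]:
    """Resolves a single -1 axis so the axis product equals num_devices."""
    known = math.prod(d for d in mesh_shape if d != -1)
    assert num_devices % known == 0, "device count must be divisible by fixed axes"
    return tuple(num_devices // known if d == -1 else d for d in mesh_shape)


# E.g. on 512 chips, (data=-1, fsdp=64) resolves to data=8,
# while (data=-1, fsdp=256) leaves data=2.
print(infer_mesh_shape((-1, 64), 512))   # (8, 64)
print(infer_mesh_shape((-1, 256), 512))  # (2, 256)
```

This is presumably also what the earlier commit subject "fsdp=256 data=-1 so ici_dp=1" refers to: with `fsdp=256` filling a 256-chip slice, the inferred data axis has no room within a slice, so any remaining data parallelism spans across slices.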