diff --git a/Dockerfile b/Dockerfile index 29db664d3..dc011dcb0 100644 --- a/Dockerfile +++ b/Dockerfile @@ -83,7 +83,7 @@ ENTRYPOINT ["/opt/apache/beam/boot"] FROM base AS tpu -ARG EXTRAS= +ARG EXTRAS=orbax ENV UV_FIND_LINKS=https://storage.googleapis.com/jax-releases/libtpu_releases.html # Ensure we install the TPU version, even if building locally. diff --git a/axlearn/cloud/gcp/job.py b/axlearn/cloud/gcp/job.py index 92cb8d045..5099aa3c3 100644 --- a/axlearn/cloud/gcp/job.py +++ b/axlearn/cloud/gcp/job.py @@ -134,7 +134,11 @@ def _build_jobset(self) -> Nested[Any]: return dict( metadata=dict(name=cfg.name, annotations=annotations), spec=dict( - failurePolicy=dict(maxRestarts=cfg.max_tries - 1), + failurePolicy=dict( + maxRestarts=cfg.max_tries - 1, + restartStrategy="BlockingRecreate" + # maxRestarts=cfg.max_tries - 1, + ), replicatedJobs=self._builder(), ), ) diff --git a/axlearn/cloud/gcp/jobset_utils.py b/axlearn/cloud/gcp/jobset_utils.py index ab3a7daaf..6a435fd79 100644 --- a/axlearn/cloud/gcp/jobset_utils.py +++ b/axlearn/cloud/gcp/jobset_utils.py @@ -451,6 +451,9 @@ def _build_container(self) -> Nested[Any]: if cfg.enable_tpu_ici_resiliency is not None: env_vars["ENABLE_ICI_RESILIENCY"] = str(cfg.enable_tpu_ici_resiliency).lower() + env_vars["TPU_PREMAPPED_BUFFER_SIZE"] = "137438953472" + env_vars["TPU_PREMAPPED_BUFFER_TRANSFER_THRESHOLD_BYTES"] = "137438953472" + resources = {"limits": {"google.com/tpu": system.chips_per_vm}} # Set request memory by host machine type. machine_memory_gi = GCE_MACHINE_TYPE_TO_MEMORY_CHARACTERISTICS.get( @@ -690,7 +693,7 @@ def _build_pod(self) -> Nested[Any]: spec = dict( # NOTE: Don't set hostNetwork or dnsPolicy for compat with Workload Identity. - terminationGracePeriodSeconds=60, + terminationGracePeriodSeconds=300, # Fail if any pod fails, and allow retries to happen at JobSet level. restartPolicy="Never", # https://kubernetes.io/docs/tasks/network/customize-hosts-file-for-pods/#adding-additional-entries-with-hostaliases diff --git a/axlearn/common/checkpointer_orbax.py b/axlearn/common/checkpointer_orbax.py index 2eb205b7f..c96c48d17 100644 --- a/axlearn/common/checkpointer_orbax.py +++ b/axlearn/common/checkpointer_orbax.py @@ -13,9 +13,12 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union import jax +import numpy as np import orbax.checkpoint as ocp import tensorflow as tf from absl import logging +from orbax.checkpoint._src.metadata import array_metadata_store as array_metadata_store_lib +from orbax.checkpoint._src.serialization.type_handlers import ArrayHandler from axlearn.common import utils from axlearn.common.checkpointer import ( @@ -187,15 +190,18 @@ class Config(BaseCheckpointer.Config): Attributes: keep_last_n: Keep this many past ckpts. + keep_every_n_steps: If set, keep a checkpoint every n steps. validation_type: Checkpoint validation during restore. async_timeout_secs: Timeout for async barrier in seconds. 
""" keep_last_n: int = 1 + keep_every_n_steps: Optional[int] = None validation_type: CheckpointValidationType = CheckpointValidationType.EXACT async_timeout_secs: int = 300 max_concurrent_save_gb: Optional[int] = None max_concurrent_restore_gb: Optional[int] = None + enable_single_replica_ckpt_restoring: bool = True @classmethod def checkpoint_paths(cls, base_dir: str) -> List[str]: @@ -237,6 +243,7 @@ def save_fn_with_summaries(step: int, last_saved_step: Optional[int]) -> bool: options=ocp.CheckpointManagerOptions( create=True, max_to_keep=cfg.keep_last_n, + keep_period=cfg.keep_every_n_steps, enable_async_checkpointing=True, step_name_format=self._name_format, should_save_fn=save_fn_with_summaries, @@ -321,11 +328,33 @@ def restore( cfg: OrbaxCheckpointer.Config = self.config + if cfg.enable_single_replica_ckpt_restoring: + array_handler = ocp.type_handlers.SingleReplicaArrayHandler( + replica_axis_index=0, + broadcast_memory_limit_bytes=1024 * 1024 * 1000, # 1000 MB limit + ) + ocp.type_handlers.register_type_handler(jax.Array, array_handler, override=True) + def _restore_args(x: Any) -> ocp.RestoreArgs: if isinstance(x, (Tensor, TensorSpec)): - return ocp.checkpoint_utils.construct_restore_args( - jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding) - ) + if cfg.enable_single_replica_ckpt_restoring: + pspec = x.sharding.spec + mesh = x.sharding.mesh + replica_axis_index = 0 + replica_devices = _replica_devices(mesh.devices, replica_axis_index) + replica_mesh = jax.sharding.Mesh(replica_devices, mesh.axis_names) + single_replica_sharding = jax.sharding.NamedSharding(replica_mesh, pspec) + + return ocp.type_handlers.SingleReplicaArrayRestoreArgs( + sharding=jax.sharding.NamedSharding(mesh, pspec), + single_replica_sharding=single_replica_sharding, + global_shape=x.shape, + dtype=x.dtype, + ) + else: + return ocp.checkpoint_utils.construct_restore_args( + jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding) + ) elif isinstance(x, tf.data.Iterator): return _TfIteratorHandler.RestoreArgs(item=x) elif _GRAIN_INSTALLED and isinstance(x, _GrainIterator): @@ -349,6 +378,13 @@ def _restore_args(x: Any) -> ocp.RestoreArgs: raise ValueError(f"Failed to restore at step {step}.") from e logging.info("Could not find any completed checkpoints under %s: %s", cfg.dir, e) return None, state # Return the input state. + finally: + if cfg.enable_single_replica_ckpt_restoring: + ocp.type_handlers.register_type_handler( + jax.Array, + ArrayHandler(array_metadata_store=array_metadata_store_lib.Store()), + override=True, + ) restored_index = composite_state["index"] restored_state = composite_state["state"] @@ -375,3 +411,29 @@ def wait_until_finished(self): def stop(self, *, has_exception: bool = False): """See `BaseCheckpointer.stop` for details.""" self._manager.close() + + +def _find_idx(array: np.ndarray, replica_axis_idx: int): + """Returns the index along given dimension that the current host belongs to.""" + idx = None + for idx, val in np.ndenumerate(array): + if val.process_index == jax.process_index(): + break + return idx[replica_axis_idx] + + +def _replica_devices(device_array: np.ndarray, replica_axis_idx: int): + """Returns the devices from the replica that current host belongs to. + + Replicas are assumed to be restricted to the first axis. 
+ + Args: + device_array: devices of the mesh that can be obtained by mesh.devices() + replica_axis_idx: axis dimension along which replica is taken + + Returns: + devices inside the replica that current host is in + """ + idx = _find_idx(device_array, replica_axis_idx) + replica_result = np.take(device_array, idx, axis=replica_axis_idx) + return np.expand_dims(replica_result, axis=replica_axis_idx) diff --git a/axlearn/common/checkpointer_orbax_emergency.py b/axlearn/common/checkpointer_orbax_emergency.py index f868488e0..a1786b974 100644 --- a/axlearn/common/checkpointer_orbax_emergency.py +++ b/axlearn/common/checkpointer_orbax_emergency.py @@ -90,6 +90,9 @@ def setup(spec: str): FLAGS.process_id = info.inv_proc_id FLAGS.distributed_coordinator = info.address FLAGS.experimental_orbax_use_distributed_process_id = True + # Required for case when slices swap and ici_dp=2 or higher. + # PR that introduced this flag: https://github.com/google/orbax/pull/2222 + FLAGS.experimental_use_distributed_id_for_mesh_consistency = False yield @@ -314,6 +317,8 @@ def _init_consistent_proc_ids( # Then, rank 0 assigns inv_proc_id for worker that's missing their inv_proc_id and find the # coordinator address. if local_proc_info.cur_proc_id == 0: + jax_devices = [dev.id for dev in jax.devices()] + logging.info("jax_devices=%s", jax_devices) ids = client.key_value_dir_get(key_prefix) proc_infos: list[_ProcessInfo] = [] diff --git a/axlearn/common/compiler_options.py b/axlearn/common/compiler_options.py index 06239e54b..793989144 100644 --- a/axlearn/common/compiler_options.py +++ b/axlearn/common/compiler_options.py @@ -58,7 +58,12 @@ def default_xla_options( # cause the step time to be double. You should increase this # further if you see "Allocator failed to allocate". A feature # to dynamically allocate may come later: b/380514965 - megascale_grpc_premap_memory_bytes=17179869184, + # Needed for orbax emergency checkpointer + # Needed for decent perf + megascale_grpc_premap_memory_bytes=137438953472, + # needed for restore consistent hash + megascale_jax_offset_launch_id_by_module_name=True, + megascale_jax_use_device_set_based_launch_id=False, # Flag controlling the maximum number of overlapping host offloadings. xla_tpu_host_transfer_overlap_limit=24, # Flag controlling the maximum number of overlapping cross-DCN send/recv. diff --git a/axlearn/common/trainer.py b/axlearn/common/trainer.py index 0603f7bf9..692c2e24a 100644 --- a/axlearn/common/trainer.py +++ b/axlearn/common/trainer.py @@ -613,7 +613,7 @@ def run( ) self.vlog(3, "Done step %s", self.step) num_steps += 1 - if num_steps % 100 == 0: + if num_steps % 1 == 0: now = time.perf_counter() average_step_time = (now - start_time) / num_steps self._step_log("Average step time: %s seconds", average_step_time) @@ -927,42 +927,40 @@ def restore_checkpoint(self, restore_step: Optional[int] = None) -> Optional[int **ckpt_state_spec, input_iter=iter(self.input.dataset()) ) restore_input_iter = cfg.save_input_iterator - try: - # Try to restore with `input_iter`. - step, ckpt_state = self.checkpointer.restore( - step=restore_step, - state=( - ckpt_state_spec_with_input_iter if restore_input_iter else ckpt_state_spec - ), - ) - if step is not None: - self.vlog( - 0, - "Restored checkpoint at %s with restore_input_iter=%s", - step, - restore_input_iter, - ) - except ValueError as e: - logging.warning( - "Attempt to restore checkpoint with restore_input_iter=%s failed: %s", + # try: + # Try to restore with `input_iter`. 
+ step, ckpt_state = self.checkpointer.restore( + step=restore_step, + state=(ckpt_state_spec_with_input_iter if restore_input_iter else ckpt_state_spec), + ) + if step is not None: + self.vlog( + 0, + "Restored checkpoint at %s with restore_input_iter=%s", + step, restore_input_iter, - e, - ) - # Restore with a different restore_input_iter setting. - restore_input_iter = not restore_input_iter - step, ckpt_state = self.checkpointer.restore( - step=restore_step, - state=( - ckpt_state_spec_with_input_iter if restore_input_iter else ckpt_state_spec - ), ) - if step is not None: - self.vlog( - 0, - "Restored checkpoint at %s with restore_input_iter=%s", - step, - restore_input_iter, - ) + # except ValueError as e: + # logging.warning( + # "Attempt to restore checkpoint with restore_input_iter=%s failed: %s", + # restore_input_iter, + # e, + # ) + # # Restore with a different restore_input_iter setting. + # restore_input_iter = not restore_input_iter + # step, ckpt_state = self.checkpointer.restore( + # step=restore_step, + # state=( + # ckpt_state_spec_with_input_iter if restore_input_iter else ckpt_state_spec + # ), + # ) + # if step is not None: + # self.vlog( + # 0, + # "Restored checkpoint at %s with restore_input_iter=%s", + # step, + # restore_input_iter, + # ) if step is not None: self._step = step self._trainer_state = TrainerState( @@ -1097,7 +1095,7 @@ def _run_step( # Run the compiled function. self._trainer_state, outputs = compiled_train_step_fn(self.trainer_state, input_batch) - if self.step % 100 == 0 or 0 <= self.step <= 5: + if self.step % 1 == 0 or 0 <= self.step <= 5: self._step_log( "loss=%s aux=%s", outputs["loss"], diff --git a/axlearn/experiments/text/gpt/common.py b/axlearn/experiments/text/gpt/common.py index 57d606dab..d24622a50 100644 --- a/axlearn/experiments/text/gpt/common.py +++ b/axlearn/experiments/text/gpt/common.py @@ -643,6 +643,7 @@ def get_trainer_config_fn( keep_every_n_steps: int = 50_000, save_every_n_steps: Optional[int] = None, init_state_builder: Optional[state_builder.Builder.Config] = None, + checkpointer: str = "", ) -> TrainerConfigFn: """Builds a TrainerConfigFn according to the model and input specs. @@ -709,13 +710,56 @@ def config_fn() -> InstantiableConfig: ) ) cfg.evalers[name] = evaler_cfg + # Summaries and checkpoints. - cfg.checkpointer.save_policy = config_for_function(every_n_steps_and_last_policy).set( - n=save_every_n_steps or min(eval_every_n_steps, 5_000), - max_step=max_step, - ) - cfg.checkpointer.keep_every_n_steps = min(max_step, keep_every_n_steps) - cfg.checkpointer.keep_last_n = 3 + calculated_save_every_n_steps = save_every_n_steps or min(eval_every_n_steps, 100) + + if not checkpointer: + cfg.checkpointer.save_policy = config_for_function(every_n_steps_and_last_policy).set( + n=calculated_save_every_n_steps, + max_step=max_step, + ) + cfg.checkpointer.keep_every_n_steps = min(max_step, keep_every_n_steps) + cfg.checkpointer.keep_last_n = 3 + elif checkpointer == "OrbaxEmergencyCheckpointer": + # Prevent global dependency on Orbax. 
+ # pylint: disable-next=import-outside-toplevel + from axlearn.common.checkpointer_orbax_emergency import OrbaxEmergencyCheckpointer + + ckpt_config: OrbaxEmergencyCheckpointer.Config = ( + OrbaxEmergencyCheckpointer.default_config() + ) + ckpt_config.save_policy = config_for_function(every_n_steps_and_last_policy).set( + n=calculated_save_every_n_steps, + max_step=max_step, + ) + ckpt_config.local_save_policy = config_for_function(every_n_steps_and_last_policy).set( + n=20, + max_step=max_step, + ) + ckpt_config.local_dir = "/host-tmp/checkpoints" + ckpt_config.keep_every_n_steps = min(max_step, keep_every_n_steps) + ckpt_config.keep_last_n = 3 + ckpt_config.replica_axis_index = 1 + cfg.checkpointer = ckpt_config + elif checkpointer == "OrbaxRegularCheckpointer": + # Prevent global dependency on Orbax. + # pylint: disable-next=import-outside-toplevel + from axlearn.common.checkpointer_orbax import OrbaxCheckpointer + + ckpt_config: OrbaxCheckpointer.Config = OrbaxCheckpointer.default_config() + ckpt_config.save_policy = config_for_function(every_n_steps_and_last_policy).set( + n=calculated_save_every_n_steps, + max_step=max_step, + ) + ckpt_config.keep_every_n_steps = min(max_step, keep_every_n_steps) + ckpt_config.keep_last_n = 3 + cfg.checkpointer = ckpt_config + + # Save the data iterator as part of the checkpointing process. + # default is false. + # cfg.save_input_iterator = True + cfg.summary_writer.write_every_n_steps = min(eval_every_n_steps, 100) cfg.summary_writer.max_queue = 1000 if len(mesh_axis_names) != len(mesh_shape): diff --git a/axlearn/experiments/text/gpt/fuji.py b/axlearn/experiments/text/gpt/fuji.py index 57c910c70..711613958 100644 --- a/axlearn/experiments/text/gpt/fuji.py +++ b/axlearn/experiments/text/gpt/fuji.py @@ -366,6 +366,10 @@ def get_trainer_kwargs( ), ) elif model_size == "7B": + # pylint: disable=import-outside-toplevel + import jax + + gbs = len(jax.devices()) trainer_kwargs = dict( model_kwargs=dict( num_layers=32, @@ -378,9 +382,9 @@ def get_trainer_kwargs( ), learner_kwargs=dict(peak_lr=3e-4, weight_decay=0.1), max_sequence_length=max_sequence_length, - train_batch_size=train_batch_size, + train_batch_size=gbs, max_step=max_step, - mesh_shape=mesh_shape_from_axes(data=-1, fsdp=8), + mesh_shape=mesh_shape_from_axes(data=-1, fsdp=16), mesh_rules=( # Step time: # v1 on tpu-v4-1024 (512 chips): 3.03s @@ -633,6 +637,10 @@ def get_trainer_kwargs( ), ) elif model_size == "70B": + # pylint: disable=import-outside-toplevel + import jax + + gbs = len(jax.devices()) trainer_kwargs = dict( model_kwargs=dict( num_layers=80, @@ -648,7 +656,7 @@ def get_trainer_kwargs( ), learner_kwargs=dict(peak_lr=1.5e-4, weight_decay=0.1), max_sequence_length=max_sequence_length, - train_batch_size=train_batch_size, + train_batch_size=gbs, max_step=max_step, mesh_shape=mesh_shape_from_axes(fsdp=-1), mesh_rules=( @@ -661,7 +669,7 @@ def get_trainer_kwargs( ChainConfigModifier.default_config().set( config_modifiers=[ MeshShapeModifier.default_config().set( - mesh_shape=mesh_shape_from_axes(data=-1, fsdp=256) + mesh_shape=mesh_shape_from_axes(data=-1, fsdp=64) ), RematSpecModifier.default_config().set( remat_policies={ @@ -710,7 +718,7 @@ def get_trainer_kwargs( ChainConfigModifier.default_config().set( config_modifiers=[ MeshShapeModifier.default_config().set( - mesh_shape=mesh_shape_from_axes(data=-1, fsdp=256) + mesh_shape=mesh_shape_from_axes(data=-1, fsdp=64) ), RematSpecModifier.default_config().set( remat_policies={ @@ -914,17 +922,30 @@ def trainer_configs( """ arch = 
"fuji" config_map = {} - for version, model_size, flash_attention in itertools.product( - Version, MODEL_SIZES, [True, False] + for version, model_size, flash_attention, checkpointer in itertools.product( + Version, + MODEL_SIZES, + [True, False], + ["", "OrbaxEmergencyCheckpointer", "OrbaxRegularCheckpointer"], ): if model_size not in TOTAL_TOKENS[version]: # This combination does not exist. continue vocab_size = VOCAB_SIZE[version] + + current_suffix_parts = [] + if flash_attention: + current_suffix_parts.append("-flash") + if checkpointer == "OrbaxEmergencyCheckpointer": + current_suffix_parts.append("-orbaxem") + elif checkpointer == "OrbaxRegularCheckpointer": + current_suffix_parts.append("-orbax") + current_suffix = "".join(current_suffix_parts) + config_name = make_config_name( arch=arch, model_size=model_size, version=f"v{version.value}", - suffix="-flash" if flash_attention else "", + suffix=current_suffix, ) kwargs = get_trainer_kwargs( model_size, vocab_size=vocab_size, version=version, flash_attention=flash_attention @@ -939,6 +960,7 @@ def trainer_configs( evalers=evaler_config_dict( eval_input_sources(vocab_size=vocab_size, max_sequence_length=max_sequence_length), ), + checkpointer=checkpointer, **kwargs, ) diff --git a/force_delete_terminating_pods.sh b/force_delete_terminating_pods.sh new file mode 100755 index 000000000..dea647c28 --- /dev/null +++ b/force_delete_terminating_pods.sh @@ -0,0 +1,64 @@ +#!/bin/bash +# +# This script finds pods that are stuck in the 'Terminating' state for more +# than a specified duration and forcibly deletes them. + +# --- Configuration --- +# The maximum duration (in seconds) a pod is allowed to be in the Terminating state. +STUCK_DURATION_SECONDS=300 + +# How often (in seconds) the script should check for stuck pods. +CHECK_INTERVAL_SECONDS=60 +# --- End Configuration --- + +echo "Starting pod termination monitor..." +echo "Stuck Threshold: ${STUCK_DURATION_SECONDS}s | Check Interval: ${CHECK_INTERVAL_SECONDS}s" + +while true; do + echo "$(date '+%Y-%m-%d %H:%M:%S') - Checking for stuck pods..." + + # Get all pods with a deletion timestamp in JSON format. + # The 'deletionTimestamp' field is only present for pods that are being terminated. + stuck_pods_json=$(kubectl get pods -o json | jq -c '.items[] | select(.metadata.deletionTimestamp)') + + if [ -z "$stuck_pods_json" ]; then + echo "No pods are currently in a terminating state." + else + echo "$stuck_pods_json" | while read -r pod_json; do + # Extract details from the pod's JSON data + pod_name=$(echo "$pod_json" | jq -r '.metadata.name') + pod_namespace=$(echo "$pod_json" | jq -r '.metadata.namespace') + deletion_timestamp_str=$(echo "$pod_json" | jq -r '.metadata.deletionTimestamp') + + # Sanitize the timestamp for macOS `date` by removing fractional seconds and the 'Z' suffix. + # This handles formats like "2024-01-01T12:34:56.123456Z" -> "2024-01-01T12:34:56" + sanitized_timestamp_str=$(echo "$deletion_timestamp_str" | sed -e 's/\.[0-9]*Z$/Z/' -e 's/Z$//') + + # Convert the RFC3339 timestamp to a Unix epoch timestamp + # Works on both GNU and BSD (macOS) date commands. + if date --version >/dev/null 2>&1; then # GNU date + deletion_ts=$(date -d "$sanitized_timestamp_str" +%s) + else # BSD date + # On macOS, use -u to interpret the time as UTC. 
+ deletion_ts=$(date -u -jf "%Y-%m-%dT%H:%M:%S" "$sanitized_timestamp_str" +%s) + fi + + # Get the current time as a Unix epoch timestamp + now_ts=$(date +%s) + + # Calculate how long the pod has been terminating + duration=$((now_ts - deletion_ts)) + + echo " - Checking pod '$pod_name' in namespace '$pod_namespace' (terminating for ${duration}s)" + + if [ "$duration" -gt "$STUCK_DURATION_SECONDS" ]; then + echo " -> STUCK! Pod '$pod_name' has been terminating for ${duration}s. Forcing deletion." + # Force delete the pod. The --grace-period=0 is crucial. + kubectl delete pod "$pod_name" -n "$pod_namespace" --force --grace-period=0 + fi + done + fi + + echo "Check complete. Sleeping for ${CHECK_INTERVAL_SECONDS} seconds..." + sleep "$CHECK_INTERVAL_SECONDS" +done diff --git a/pyproject.toml b/pyproject.toml index 57e415e7e..1d346cc03 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,7 @@ core = [ "protobuf>=3.20.3", "tensorboard-plugin-profile==2.15.1", # This has both x86 and arm64 wheels. Underneath the hood it uses tensorflow-macos since 2.13. + "tokenizers", "tensorflow==2.17.1", "tensorflow-datasets>=4.9.2", "tensorflow-io>=0.37.1", # for tensorflow-2.16. Note that 0.37.0 results in "pure virtual method called". @@ -154,7 +155,11 @@ mmau = [ # Orbax checkpointing. orbax = [ "humanize==4.10.0", - "orbax-checkpoint==0.11.15", + # "orbax-checkpoint==0.11.21", + # "orbax-checkpoint @ git+https://github.com/samos123/orbax.git@v0.11.21-em#subdirectory=checkpoint" + # Single replica restore from GCS with Orbax EM + "orbax-checkpoint @ git+https://github.com/samos123/orbax.git@v0.11.21-em-aug13#subdirectory=checkpoint" + # "orbax-checkpoint @ git+https://github.com/google/orbax.git@main#subdirectory=checkpoint" ] # Audio dependencies. audio = [ diff --git a/test-orbax.sh b/test-orbax.sh new file mode 100755 index 000000000..7f9e13744 --- /dev/null +++ b/test-orbax.sh @@ -0,0 +1,88 @@ +#!/usr/bin/env bash + +set -xe + +export NUM_REPLICAS=${NUM_REPLICAS:-2} +export JOBSET_NAME=${JOBSET_NAME:-$USER} +export BASTION_TIER=disabled +export GKE_CLUSTER=$(axlearn gcp config | grep gke_cluster | awk '{ print $3 }' | tr -d '"') +# Switch to tpu-v6e-256 if on scale cluster +export INSTANCE_TYPE=${INSTANCE_TYPE:-"tpu-v6e-16"} +# Switch to tpu-v6e-256-4 if on scale cluster +export MESH_SELECTOR=${MESH_SELECTOR:-"tpu-v6e-16"} +# Need to use tiktoken when saving data iterator +# export CONFIG=${CONFIG:-"fuji-8B-v3-tiktoken-flash-orbax"} +export CONFIG=${CONFIG:-"fuji-7B-v3-flash-orbaxem"} +export PROJECT_ID=$(gcloud config get project) +# export TRAINER_DIR=gs://tpu-prod-env-multipod-use4 +export TRAINER_DIR=gs://tpu-prod-env-one-vm-saw1-a + + +# Example for v6e-256 +# MESH_SELECTOR=tpu-v6e-256-4 INSTANCE_TYPE=tpu-v6e-256 ./test-orbax.sh + +# The bundle step is needed if you run on cloudtop +# uncomment if you use cloudtop +# axlearn gcp bundle --name=$JOBSET_NAME \ +# --bundler_spec=allow_dirty=True \ +# --bundler_type=artifactregistry \ +# --bundler_spec=dockerfile=Dockerfile \ +# --bundler_spec=image=tpu \ +# --bundler_spec=target=tpu + +# Only enable kueue when running on scale testing cluster +# --queue=multislice-queue \ +# --priority_class=very-high \ +# --trainer_dir=gs://tess-checkpoints-us-west1/${JOBSET_NAME}-nr-${NUM_REPLICAS}/ \ +# For goodput logging +# --recorder_type=axlearn.cloud.gcp.measurement:goodput \ +# --recorder_spec=name=goodput_${JOBSET_NAME} \ +# --recorder_spec=upload_dir=${TRAINER_DIR}/summaries \ +# --recorder_spec=upload_interval=30 \ +# 
--recorder_spec=rolling_window_size=3600,7200,10800,86400 \ +# --trace_at_steps=29,59,89,119,149,179,209,239,269,299,329,359,389,419,449,479,509,539,569,599,629,659,689,719 + +# Check if CONFIG ends with "orbaxem" +if [[ "$CONFIG" == *"orbaxem"* ]]; then + echo "Running with Orbax emergency checkpointer." + axlearn gcp launch run --cluster=$GKE_CLUSTER \ + --runner_name gke_tpu_single \ + --queue=multislice-queue \ + --priority_class=very-high \ + --name=$JOBSET_NAME \ + --instance_type=${INSTANCE_TYPE} \ + --host_mount_spec=name=tmp,host_path=/tmp,mount_path=/host-tmp \ + --num_replicas=${NUM_REPLICAS} \ + --bundler_spec=allow_dirty=True \ + --bundler_type=artifactregistry --bundler_spec=image=tpu \ + --bundler_spec=dockerfile=Dockerfile --bundler_spec=target=tpu \ + -- "ulimit -n 1048576; ulimit -c 0; python3 -c 'import jax; jax.devices()'; python3 -m axlearn.common.launch_trainer_main" \ + --init_module=axlearn.common.checkpointer_orbax_emergency:local_ckpt_dir=/host-tmp/checkpoints \ + --module=text.gpt.c4_trainer \ + --config=${CONFIG} \ + --trainer_dir=${TRAINER_DIR}/${JOBSET_NAME} \ + --data_dir=gs://axlearn-public/tensorflow_datasets \ + --jax_backend=tpu \ + --mesh_selector=${MESH_SELECTOR} \ + --initialization_timeout=1200 + +else + echo "Running Orbax regular checkpointer or AXLearn native." + axlearn gcp launch run --cluster=$GKE_CLUSTER \ + --runner_name gke_tpu_single \ + --name=$JOBSET_NAME \ + --instance_type=${INSTANCE_TYPE} \ + --num_replicas=${NUM_REPLICAS} \ + --bundler_spec=allow_dirty=True \ + --bundler_type=artifactregistry --bundler_spec=image=tpu \ + --bundler_spec=dockerfile=Dockerfile --bundler_spec=target=tpu \ + -- "ulimit -n 1048576; ulimit -c 0; python3 -c 'import jax; jax.devices()'; python3 -m axlearn.common.launch_trainer_main" \ + --module=text.gpt.c4_trainer \ + --config=${CONFIG} \ + --trainer_dir=gs://${PROJECT_ID}-axlearn/${JOBSET_NAME}-nr-${NUM_REPLICAS}/ \ + --data_dir=gs://axlearn-public/tensorflow_datasets \ + --jax_backend=tpu \ + --mesh_selector=${MESH_SELECTOR} \ + --initialization_timeout=1200 \ + --trace_at_steps=29,59,89,119,149,179,209,239,269,299,329,359,389,419,449,479,509,539,569,599,629,659,689,719 +fi
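
A note on the repeated constant above: TPU_PREMAPPED_BUFFER_SIZE, TPU_PREMAPPED_BUFFER_TRANSFER_THRESHOLD_BYTES, and megascale_grpc_premap_memory_bytes are all set to 137438953472 bytes, i.e. 128 GiB. A one-line Python check:

# 137438953472 bytes is 2**37, i.e. 128 GiB.
assert 137438953472 == 2**37 == 128 * 1024**3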
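
A minimal, self-contained sketch of the single-replica restore device selection added to axlearn/common/checkpointer_orbax.py above. The FakeDevice class, the 2x4 mesh shape, and passing the process index explicitly are illustration-only assumptions so the snippet runs without jax or a TPU; find_idx/replica_devices mirror _find_idx/_replica_devices from the diff:

import numpy as np

class FakeDevice:
    """Stand-in for a jax.Device; only .process_index is used here."""
    def __init__(self, process_index):
        self.process_index = process_index

# Hypothetical 2 (replica) x 4 (fsdp) device mesh: row 0 on process 0, row 1 on process 1.
devices = np.array([[FakeDevice(0)] * 4, [FakeDevice(1)] * 4])

def find_idx(array, replica_axis_idx, my_process_index):
    # Same walk as _find_idx above, with the process index passed in
    # instead of calling jax.process_index().
    idx = None
    for idx, val in np.ndenumerate(array):
        if val.process_index == my_process_index:
            break
    return idx[replica_axis_idx]

def replica_devices(device_array, replica_axis_idx, my_process_index):
    # Mirrors _replica_devices above: keep only the replica slice containing the
    # current host, with the replica axis retained as a singleton dimension.
    idx = find_idx(device_array, replica_axis_idx, my_process_index)
    return np.expand_dims(np.take(device_array, idx, axis=replica_axis_idx), axis=replica_axis_idx)

# A host in process 1 keeps only the second row, with shape (1, 4); that shape is what
# allows building the single-replica jax.sharding.Mesh with the original axis_names.
print(replica_devices(devices, 0, my_process_index=1).shape)  # -> (1, 4)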
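
The new checkpointer suffixes compose with the existing -flash suffix into config names like the CONFIG value used in test-orbax.sh. A small sketch of that composition, assuming make_config_name joins arch, model size, version, and suffix with dashes (the helper below is hypothetical, not the axlearn function):

def config_name(model_size, version, flash_attention, checkpointer):
    # Mirrors the suffix composition added to trainer_configs() in fuji.py.
    suffix = ""
    if flash_attention:
        suffix += "-flash"
    if checkpointer == "OrbaxEmergencyCheckpointer":
        suffix += "-orbaxem"
    elif checkpointer == "OrbaxRegularCheckpointer":
        suffix += "-orbax"
    return f"fuji-{model_size}-v{version}{suffix}"

print(config_name("7B", 3, True, "OrbaxEmergencyCheckpointer"))  # fuji-7B-v3-flash-orbaxem
print(config_name("70B", 3, True, "OrbaxRegularCheckpointer"))   # fuji-70B-v3-flash-orbax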