From a4c641c381b95a837603c137ce76745e364db14a Mon Sep 17 00:00:00 2001
From: lkolluru05
Date: Fri, 6 Jun 2025 20:55:18 +0000
Subject: [PATCH 1/3] Orbax to regular checkpointer conversion

---
 axlearn/cloud/gcp/jobset_utils.py           | 10 ++---
 .../common/checkpointer_orbax_emergency.py  |  9 ++++
 axlearn/experiments/text/gpt/common.py      | 41 ++++++++++++++++---
 axlearn/experiments/text/gpt/fuji.py        | 25 ++++++++---
 4 files changed, 69 insertions(+), 16 deletions(-)

diff --git a/axlearn/cloud/gcp/jobset_utils.py b/axlearn/cloud/gcp/jobset_utils.py
index ab3a7daaf..5979bb2b2 100644
--- a/axlearn/cloud/gcp/jobset_utils.py
+++ b/axlearn/cloud/gcp/jobset_utils.py
@@ -508,10 +508,10 @@ def _build_uploader_container(
         dst = f"{cfg.output_dir}/output/$HOSTNAME/"
         interval_s = 60
         sync_command = f"while true; do gsutil -m rsync -r {src} {dst}; sleep {interval_s}; done"
-        resources = {
-            "requests": {"cpu": "100m", "memory": "128Mi"},
-            "limits": {"cpu": "500m", "memory": "256Mi"},
-        }
+        # resources = {
+        #     "requests": {"cpu": "100m", "memory": "128Mi"},
+        #     "limits": {"cpu": "500m", "memory": "256Mi"},
+        # }
         return dict(
             name="output-uploader",
             image="google/cloud-sdk:alpine",
@@ -520,7 +520,7 @@ def _build_uploader_container(
             restartPolicy="Always",
             command=["/bin/sh", "-c"],
             args=[sync_command],
-            resources=resources,
+            # resources=resources,
             volumeMounts=[output_volume_mount],
         )

diff --git a/axlearn/common/checkpointer_orbax_emergency.py b/axlearn/common/checkpointer_orbax_emergency.py
index 30e97caff..4b1dace2b 100644
--- a/axlearn/common/checkpointer_orbax_emergency.py
+++ b/axlearn/common/checkpointer_orbax_emergency.py
@@ -54,6 +54,11 @@
 FLAGS = flags.FLAGS


+def save_axlearn_checkpoint(step: int, state, directory: str, name: str):
+    cfg = Checkpointer.default_config().set(name=name, dir=directory)
+    ckpt = cfg.instantiate(parent=None)
+    ckpt.save(step=step, state=state)
+    ckpt.wait_until_finished()

 @contextmanager
 def setup(spec: str):
@@ -819,6 +824,10 @@ def restore(
         )
         time_diff = time.perf_counter() - start_t
         logging.info("Took %ss to restore emergency checkpoint from %s.", time_diff, cfg.dir)
+
+        logging.info("Saving a non-Orbax checkpoint from the restored Orbax state...")
+        save_axlearn_checkpoint(step, restored_state, cfg.dir, cfg.name)
+
         return step, restored_state

     def wait_until_finished(self):
diff --git a/axlearn/experiments/text/gpt/common.py b/axlearn/experiments/text/gpt/common.py
index 57d606dab..5bd66af89 100644
--- a/axlearn/experiments/text/gpt/common.py
+++ b/axlearn/experiments/text/gpt/common.py
@@ -643,6 +643,7 @@ def get_trainer_config_fn(
     keep_every_n_steps: int = 50_000,
     save_every_n_steps: Optional[int] = None,
     init_state_builder: Optional[state_builder.Builder.Config] = None,
+    checkpointer: str = "",
 ) -> TrainerConfigFn:
     """Builds a TrainerConfigFn according to the model and input specs.

@@ -710,12 +711,40 @@ def config_fn() -> InstantiableConfig:
             )
             cfg.evalers[name] = evaler_cfg
         # Summaries and checkpoints.
-        cfg.checkpointer.save_policy = config_for_function(every_n_steps_and_last_policy).set(
-            n=save_every_n_steps or min(eval_every_n_steps, 5_000),
-            max_step=max_step,
-        )
-        cfg.checkpointer.keep_every_n_steps = min(max_step, keep_every_n_steps)
-        cfg.checkpointer.keep_last_n = 3
+        calculated_save_every_n_steps = save_every_n_steps or min(eval_every_n_steps, 500)
+
+        if not checkpointer:
+            cfg.checkpointer.save_policy = config_for_function(every_n_steps_and_last_policy).set(
+                n=calculated_save_every_n_steps,
+                max_step=max_step,
+            )
+            cfg.checkpointer.keep_every_n_steps = min(max_step, keep_every_n_steps)
+            cfg.checkpointer.keep_last_n = 3
+        elif checkpointer == "OrbaxEmergencyCheckpointer":
+            # Prevent global dependency on Orbax.
+            # pylint: disable-next=import-outside-toplevel
+            from axlearn.common.checkpointer_orbax_emergency import OrbaxEmergencyCheckpointer
+
+            ckpt_config: OrbaxEmergencyCheckpointer.Config = (
+                OrbaxEmergencyCheckpointer.default_config()
+            )
+            ckpt_config.save_policy = config_for_function(every_n_steps_and_last_policy).set(
+                # n=calculated_save_every_n_steps,
+                # Every 15 minutes or more is recommended.
+                n=200,
+                max_step=max_step,
+            )
+            ckpt_config.local_save_policy = config_for_function(every_n_steps_and_last_policy).set(
+                # n=calculated_save_every_n_steps,
+                # Every 2 minutes or more is generally recommended.
+                n=30,
+                max_step=max_step,
+            )
+            ckpt_config.local_dir = "/host-tmp/checkpoints"
+            ckpt_config.keep_every_n_steps = min(max_step, keep_every_n_steps)
+            ckpt_config.keep_last_n = 3
+            ckpt_config.replica_axis_index = 1
+            cfg.checkpointer = ckpt_config
         cfg.summary_writer.write_every_n_steps = min(eval_every_n_steps, 100)
         cfg.summary_writer.max_queue = 1000
         if len(mesh_axis_names) != len(mesh_shape):
diff --git a/axlearn/experiments/text/gpt/fuji.py b/axlearn/experiments/text/gpt/fuji.py
index 57c910c70..547e001e5 100644
--- a/axlearn/experiments/text/gpt/fuji.py
+++ b/axlearn/experiments/text/gpt/fuji.py
@@ -366,6 +366,9 @@ def get_trainer_kwargs(
             ),
         )
     elif model_size == "7B":
+        import jax
+
+        gbs = len(jax.devices())
         trainer_kwargs = dict(
             model_kwargs=dict(
                 num_layers=32,
@@ -378,7 +381,7 @@ def get_trainer_kwargs(
             ),
             learner_kwargs=dict(peak_lr=3e-4, weight_decay=0.1),
             max_sequence_length=max_sequence_length,
-            train_batch_size=train_batch_size,
+            train_batch_size=gbs,
             max_step=max_step,
             mesh_shape=mesh_shape_from_axes(data=-1, fsdp=8),
             mesh_rules=(
@@ -633,6 +636,9 @@ def get_trainer_kwargs(
             ),
         )
     elif model_size == "70B":
+        import jax
+
+        devices = len(jax.devices())
         trainer_kwargs = dict(
             model_kwargs=dict(
                 num_layers=80,
@@ -648,7 +654,7 @@ def get_trainer_kwargs(
             ),
             learner_kwargs=dict(peak_lr=1.5e-4, weight_decay=0.1),
             max_sequence_length=max_sequence_length,
-            train_batch_size=train_batch_size,
+            train_batch_size=devices * 1,
             max_step=max_step,
             mesh_shape=mesh_shape_from_axes(fsdp=-1),
             mesh_rules=(
@@ -914,22 +920,30 @@ def trainer_configs(
     """
     arch = "fuji"
     config_map = {}
-    for version, model_size, flash_attention in itertools.product(
-        Version, MODEL_SIZES, [True, False]
+    for version, model_size, flash_attention, use_orbax_emergency_ckpt in itertools.product(
+        Version, MODEL_SIZES, [True, False], [False, True]
     ):
         if model_size not in TOTAL_TOKENS[version]:  # This combination does not exist.
             continue
         vocab_size = VOCAB_SIZE[version]
+
+        current_suffix_parts = []
+        if flash_attention:
+            current_suffix_parts.append("-flash")
+        if use_orbax_emergency_ckpt:
+            current_suffix_parts.append("-orbaxem")
+        current_suffix = "".join(current_suffix_parts)
         config_name = make_config_name(
             arch=arch,
             model_size=model_size,
             version=f"v{version.value}",
-            suffix="-flash" if flash_attention else "",
+            suffix=current_suffix,
         )
         kwargs = get_trainer_kwargs(
             model_size, vocab_size=vocab_size, version=version, flash_attention=flash_attention
         )
         max_sequence_length = kwargs.pop("max_sequence_length")
+        checkpointer_str = "OrbaxEmergencyCheckpointer" if use_orbax_emergency_ckpt else ""
         # pylint: disable-next=unexpected-keyword-arg,missing-kwoa
         config_map[config_name] = get_trainer_config_fn(
             train_input_source=train_input_source(
@@ -939,6 +953,7 @@ def trainer_configs(
             evalers=evaler_config_dict(
                 eval_input_sources(vocab_size=vocab_size, max_sequence_length=max_sequence_length),
             ),
+            checkpointer=checkpointer_str,
             **kwargs,
         )

From 647d9edfdab3dff8fe444093d18d427209f01be5 Mon Sep 17 00:00:00 2001
From: lkolluru05
Date: Mon, 9 Jun 2025 15:34:10 +0000
Subject: [PATCH 2/3] comments addressed

---
 axlearn/cloud/gcp/jobset_utils.py              | 10 +++++-----
 axlearn/common/checkpointer_orbax_emergency.py |  2 +-
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/axlearn/cloud/gcp/jobset_utils.py b/axlearn/cloud/gcp/jobset_utils.py
index 5979bb2b2..ab3a7daaf 100644
--- a/axlearn/cloud/gcp/jobset_utils.py
+++ b/axlearn/cloud/gcp/jobset_utils.py
@@ -508,10 +508,10 @@ def _build_uploader_container(
         dst = f"{cfg.output_dir}/output/$HOSTNAME/"
         interval_s = 60
         sync_command = f"while true; do gsutil -m rsync -r {src} {dst}; sleep {interval_s}; done"
-        # resources = {
-        #     "requests": {"cpu": "100m", "memory": "128Mi"},
-        #     "limits": {"cpu": "500m", "memory": "256Mi"},
-        # }
+        resources = {
+            "requests": {"cpu": "100m", "memory": "128Mi"},
+            "limits": {"cpu": "500m", "memory": "256Mi"},
+        }
         return dict(
             name="output-uploader",
             image="google/cloud-sdk:alpine",
@@ -520,7 +520,7 @@ def _build_uploader_container(
             restartPolicy="Always",
             command=["/bin/sh", "-c"],
             args=[sync_command],
-            # resources=resources,
+            resources=resources,
             volumeMounts=[output_volume_mount],
         )

diff --git a/axlearn/common/checkpointer_orbax_emergency.py b/axlearn/common/checkpointer_orbax_emergency.py
index 4b1dace2b..b36f0237c 100644
--- a/axlearn/common/checkpointer_orbax_emergency.py
+++ b/axlearn/common/checkpointer_orbax_emergency.py
@@ -825,7 +825,7 @@ def restore(
         time_diff = time.perf_counter() - start_t
         logging.info("Took %ss to restore emergency checkpoint from %s.", time_diff, cfg.dir)

-        logging.info("Saving a non-Orbax checkpoint from the restored Orbax state...")
+        logging.info("Saving an AXLearn tensorstore from the restored Orbax state...")
         save_axlearn_checkpoint(step, restored_state, cfg.dir, cfg.name)

         return step, restored_state

From 809557cb8b9507c4e157bf775241c11eb252e7b5 Mon Sep 17 00:00:00 2001
From: lkolluru05
Date: Tue, 17 Jun 2025 21:19:49 +0000
Subject: [PATCH 3/3] working code for orbax testing

---
 axlearn/common/checkpointer_orbax.py   | 16 +++++++++++++---
 axlearn/experiments/text/gpt/common.py | 17 ++++++++++++++++-
 axlearn/experiments/text/gpt/fuji.py   | 15 +++++++++++++--
 3 files changed, 42 insertions(+), 6 deletions(-)

diff --git a/axlearn/common/checkpointer_orbax.py b/axlearn/common/checkpointer_orbax.py
index 2eb205b7f..3c83fa8da 100644
--- a/axlearn/common/checkpointer_orbax.py
+++ b/axlearn/common/checkpointer_orbax.py
@@ -15,6 +15,7 @@
 import jax
 import orbax.checkpoint as ocp
 import tensorflow as tf
+from contextlib import contextmanager
 from absl import logging

 from axlearn.common import utils
@@ -45,6 +46,15 @@
 _GRAIN_INSTALLED = False


+@contextmanager
+def setup(spec: str):
+    """Sets up any values required by Orbax.
+
+
+    """
+
+    yield
+
 class _TfIteratorHandler(ocp.type_handlers.TypeHandler):
     """Serializes tf.data.Iterator.

@@ -237,7 +247,7 @@ def save_fn_with_summaries(step: int, last_saved_step: Optional[int]) -> bool:
             options=ocp.CheckpointManagerOptions(
                 create=True,
                 max_to_keep=cfg.keep_last_n,
-                enable_async_checkpointing=True,
+                enable_async_checkpointing=False,
                 step_name_format=self._name_format,
                 should_save_fn=save_fn_with_summaries,
                 enable_background_delete=True,
@@ -345,8 +355,8 @@ def _restore_args(x: Any) -> ocp.RestoreArgs:
             )
         except FileNotFoundError as e:
             # Orbax raises FileNotFoundError if there are no checkpoints.
-            if step is not None:
-                raise ValueError(f"Failed to restore at step {step}.") from e
+            # if step is not None:
+            #     raise ValueError(f"Failed to restore at step {step}.") from e
             logging.info("Could not find any completed checkpoints under %s: %s", cfg.dir, e)
             return None, state  # Return the input state.

diff --git a/axlearn/experiments/text/gpt/common.py b/axlearn/experiments/text/gpt/common.py
index 5bd66af89..0edfa27c1 100644
--- a/axlearn/experiments/text/gpt/common.py
+++ b/axlearn/experiments/text/gpt/common.py
@@ -17,6 +17,7 @@
 import jax.numpy as jnp
 import tensorflow as tf
 from jax.sharding import PartitionSpec
+from absl import logging

 from axlearn.common import (
     base_model,
@@ -712,14 +713,28 @@ def config_fn() -> InstantiableConfig:
             cfg.evalers[name] = evaler_cfg
         # Summaries and checkpoints.
         calculated_save_every_n_steps = save_every_n_steps or min(eval_every_n_steps, 500)
-
+        logging.info("checkpointer: %s", checkpointer)
         if not checkpointer:
+            logging.info("In no checkpointer")
             cfg.checkpointer.save_policy = config_for_function(every_n_steps_and_last_policy).set(
                 n=calculated_save_every_n_steps,
                 max_step=max_step,
             )
             cfg.checkpointer.keep_every_n_steps = min(max_step, keep_every_n_steps)
             cfg.checkpointer.keep_last_n = 3
+        elif checkpointer == "OrbaxCheckpointer":
+            logging.info("In orbax checkpointer")
+            from axlearn.common.checkpointer_orbax import OrbaxCheckpointer
+
+            ckpt_config: OrbaxCheckpointer.Config = (
+                OrbaxCheckpointer.default_config()
+            )
+            ckpt_config.save_policy = config_for_function(every_n_steps_and_last_policy).set(
+                n=calculated_save_every_n_steps,
+                max_step=max_step,
+            )
+            ckpt_config.keep_last_n = 3
+            cfg.checkpointer = ckpt_config
         elif checkpointer == "OrbaxEmergencyCheckpointer":
             # Prevent global dependency on Orbax.
             # pylint: disable-next=import-outside-toplevel
diff --git a/axlearn/experiments/text/gpt/fuji.py b/axlearn/experiments/text/gpt/fuji.py
index 547e001e5..21ea1adb0 100644
--- a/axlearn/experiments/text/gpt/fuji.py
+++ b/axlearn/experiments/text/gpt/fuji.py
@@ -31,6 +31,7 @@
     RoFormerQKVLinear,
     StackedTransformerLayer,
 )
+from absl import logging
 from axlearn.common.base_layer import RematSpec
 from axlearn.common.config import config_for_function
 from axlearn.common.decoder import LmHead
@@ -920,8 +921,13 @@ def trainer_configs(
     """
     arch = "fuji"
     config_map = {}
-    for version, model_size, flash_attention, use_orbax_emergency_ckpt in itertools.product(
-        Version, MODEL_SIZES, [True, False], [False, True]
+    orbax_options = [
+        (True, False),  # use_orbax_emergency_ckpt = True, use_orbax_ckpt = False
+        (False, True),  # use_orbax_emergency_ckpt = False, use_orbax_ckpt = True
+        (False, False),  # Neither is used
+    ]
+    for version, model_size, flash_attention, (use_orbax_emergency_ckpt, use_orbax_ckpt) in itertools.product(
+        Version, MODEL_SIZES, [True, False], orbax_options
     ):
         if model_size not in TOTAL_TOKENS[version]:  # This combination does not exist.
             continue
         vocab_size = VOCAB_SIZE[version]
@@ -932,7 +938,11 @@ def trainer_configs(
             current_suffix_parts.append("-flash")
         if use_orbax_emergency_ckpt:
             current_suffix_parts.append("-orbaxem")
+        elif use_orbax_ckpt:
+            current_suffix_parts.append("-orbax")
+
         current_suffix = "".join(current_suffix_parts)
+        logging.info(current_suffix)
         config_name = make_config_name(
             arch=arch,
             model_size=model_size,
             version=f"v{version.value}",
@@ -944,6 +954,7 @@ def trainer_configs(
         )
         max_sequence_length = kwargs.pop("max_sequence_length")
         checkpointer_str = "OrbaxEmergencyCheckpointer" if use_orbax_emergency_ckpt else ""
+        checkpointer_str = "OrbaxCheckpointer" if use_orbax_ckpt else checkpointer_str
         # pylint: disable-next=unexpected-keyword-arg,missing-kwoa
         config_map[config_name] = get_trainer_config_fn(
             train_input_source=train_input_source(
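
Note: a quick way to sanity-check the suffix logic added in PATCH 3 is to list the generated Fuji config names and look for the new "-orbax" / "-orbaxem" variants. This is a minimal sketch, assuming the experiment module's named_trainer_configs() entry point (not part of this diff) is unchanged; the printed names are illustrative.

    # List the generated Fuji config names and pick out the new Orbax variants.
    from axlearn.experiments.text.gpt import c4_trainer

    configs = c4_trainer.named_trainer_configs()
    print(sorted(name for name in configs if name.startswith("fuji-7B-v2")))
    # With the suffix logic above, the list is expected to include
    # "fuji-7B-v2-flash", "fuji-7B-v2-flash-orbax", and "fuji-7B-v2-flash-orbaxem".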