diff --git a/axlearn/common/checkpointer_orbax.py b/axlearn/common/checkpointer_orbax.py
index 2eb205b7f..3c83fa8da 100644
--- a/axlearn/common/checkpointer_orbax.py
+++ b/axlearn/common/checkpointer_orbax.py
@@ -15,6 +15,7 @@
 import jax
 import orbax.checkpoint as ocp
 import tensorflow as tf
+from contextlib import contextmanager
 from absl import logging
 
 from axlearn.common import utils
@@ -45,6 +46,15 @@
 _GRAIN_INSTALLED = False
 
 
+@contextmanager
+def setup(spec: str):
+    """Sets up any values required by Orbax.
+
+    Currently a no-op placeholder.
+    """
+
+    yield
+
 
 class _TfIteratorHandler(ocp.type_handlers.TypeHandler):
     """Serializes tf.data.Iterator.
@@ -237,7 +247,7 @@ def save_fn_with_summaries(step: int, last_saved_step: Optional[int]) -> bool:
             options=ocp.CheckpointManagerOptions(
                 create=True,
                 max_to_keep=cfg.keep_last_n,
-                enable_async_checkpointing=True,
+                enable_async_checkpointing=False,
                 step_name_format=self._name_format,
                 should_save_fn=save_fn_with_summaries,
                 enable_background_delete=True,
@@ -345,8 +355,8 @@ def _restore_args(x: Any) -> ocp.RestoreArgs:
             )
         except FileNotFoundError as e:
             # Orbax raises FileNotFoundError if there are no checkpoints.
-            if step is not None:
-                raise ValueError(f"Failed to restore at step {step}.") from e
+            # if step is not None:
+            #     raise ValueError(f"Failed to restore at step {step}.") from e
             logging.info("Could not find any completed checkpoints under %s: %s", cfg.dir, e)
             return None, state  # Return the input state.
 
diff --git a/axlearn/common/checkpointer_orbax_emergency.py b/axlearn/common/checkpointer_orbax_emergency.py
index 30e97caff..b36f0237c 100644
--- a/axlearn/common/checkpointer_orbax_emergency.py
+++ b/axlearn/common/checkpointer_orbax_emergency.py
@@ -54,6 +54,11 @@
 
 FLAGS = flags.FLAGS
 
+def save_axlearn_checkpoint(step: int, state, directory: str, name: str):
+    cfg = Checkpointer.default_config().set(name=name, dir=directory)
+    ckpt = cfg.instantiate(parent=None)
+    ckpt.save(step=step, state=state)
+    ckpt.wait_until_finished()
 
 @contextmanager
 def setup(spec: str):
@@ -819,6 +824,10 @@ def restore(
         )
         time_diff = time.perf_counter() - start_t
         logging.info("Took %ss to restore emergency checkpoint from %s.", time_diff, cfg.dir)
+
+        logging.info("Saving an AXLearn tensorstore from the restored Orbax state...")
+        save_axlearn_checkpoint(step, restored_state, cfg.dir, cfg.name)
+
         return step, restored_state
 
     def wait_until_finished(self):
diff --git a/axlearn/experiments/text/gpt/common.py b/axlearn/experiments/text/gpt/common.py
index 57d606dab..0edfa27c1 100644
--- a/axlearn/experiments/text/gpt/common.py
+++ b/axlearn/experiments/text/gpt/common.py
@@ -17,6 +17,7 @@
 import jax.numpy as jnp
 import tensorflow as tf
 from jax.sharding import PartitionSpec
+from absl import logging
 
 from axlearn.common import (
     base_model,
@@ -643,6 +644,7 @@ def get_trainer_config_fn(
     keep_every_n_steps: int = 50_000,
     save_every_n_steps: Optional[int] = None,
     init_state_builder: Optional[state_builder.Builder.Config] = None,
+    checkpointer: str = "",
 ) -> TrainerConfigFn:
     """Builds a TrainerConfigFn according to the model and input specs.
 
@@ -710,12 +712,54 @@ def config_fn() -> InstantiableConfig:
             )
             cfg.evalers[name] = evaler_cfg
         # Summaries and checkpoints.
-        cfg.checkpointer.save_policy = config_for_function(every_n_steps_and_last_policy).set(
-            n=save_every_n_steps or min(eval_every_n_steps, 5_000),
-            max_step=max_step,
-        )
-        cfg.checkpointer.keep_every_n_steps = min(max_step, keep_every_n_steps)
-        cfg.checkpointer.keep_last_n = 3
+        calculated_save_every_n_steps = save_every_n_steps or min(eval_every_n_steps, 500)
+        logging.info("checkpointer: %s", checkpointer)
+        if not checkpointer:
+            logging.info("Using the default AXLearn checkpointer.")
+            cfg.checkpointer.save_policy = config_for_function(every_n_steps_and_last_policy).set(
+                n=calculated_save_every_n_steps,
+                max_step=max_step,
+            )
+            cfg.checkpointer.keep_every_n_steps = min(max_step, keep_every_n_steps)
+            cfg.checkpointer.keep_last_n = 3
+        elif checkpointer == "OrbaxCheckpointer":
+            logging.info("Using OrbaxCheckpointer.")
+            from axlearn.common.checkpointer_orbax import OrbaxCheckpointer
+
+            ckpt_config: OrbaxCheckpointer.Config = (
+                OrbaxCheckpointer.default_config()
+            )
+            ckpt_config.save_policy = config_for_function(every_n_steps_and_last_policy).set(
+                n=calculated_save_every_n_steps,
+                max_step=max_step,
+            )
+            ckpt_config.keep_last_n = 3
+            cfg.checkpointer = ckpt_config
+        elif checkpointer == "OrbaxEmergencyCheckpointer":
+            # Prevent global dependency on Orbax.
+            # pylint: disable-next=import-outside-toplevel
+            from axlearn.common.checkpointer_orbax_emergency import OrbaxEmergencyCheckpointer
+
+            ckpt_config: OrbaxEmergencyCheckpointer.Config = (
+                OrbaxEmergencyCheckpointer.default_config()
+            )
+            ckpt_config.save_policy = config_for_function(every_n_steps_and_last_policy).set(
+                # n=calculated_save_every_n_steps,
+                # Saving to persistent storage every 15 minutes or more is recommended.
+                n=200,
+                max_step=max_step,
+            )
+            ckpt_config.local_save_policy = config_for_function(every_n_steps_and_last_policy).set(
+                # n=calculated_save_every_n_steps,
+                # Saving to local storage every 2 minutes or more is generally recommended.
+                n=30,
+                max_step=max_step,
+            )
+            ckpt_config.local_dir = "/host-tmp/checkpoints"
+            ckpt_config.keep_every_n_steps = min(max_step, keep_every_n_steps)
+            ckpt_config.keep_last_n = 3
+            ckpt_config.replica_axis_index = 1
+            cfg.checkpointer = ckpt_config
         cfg.summary_writer.write_every_n_steps = min(eval_every_n_steps, 100)
         cfg.summary_writer.max_queue = 1000
         if len(mesh_axis_names) != len(mesh_shape):
diff --git a/axlearn/experiments/text/gpt/fuji.py b/axlearn/experiments/text/gpt/fuji.py
index 57c910c70..21ea1adb0 100644
--- a/axlearn/experiments/text/gpt/fuji.py
+++ b/axlearn/experiments/text/gpt/fuji.py
@@ -31,6 +31,7 @@
     RoFormerQKVLinear,
     StackedTransformerLayer,
 )
+from absl import logging
 from axlearn.common.base_layer import RematSpec
 from axlearn.common.config import config_for_function
 from axlearn.common.decoder import LmHead
@@ -366,6 +367,9 @@ def get_trainer_kwargs(
             ),
         )
     elif model_size == "7B":
+        import jax
+
+        gbs = len(jax.devices())  # Global batch size: one example per device.
        trainer_kwargs = dict(
            model_kwargs=dict(
                num_layers=32,
@@ -378,7 +382,7 @@
             ),
             learner_kwargs=dict(peak_lr=3e-4, weight_decay=0.1),
             max_sequence_length=max_sequence_length,
-            train_batch_size=train_batch_size,
+            train_batch_size=gbs,
             max_step=max_step,
             mesh_shape=mesh_shape_from_axes(data=-1, fsdp=8),
             mesh_rules=(
@@ -633,6 +637,9 @@
             ),
         )
     elif model_size == "70B":
+        import jax
+
+        devices = len(jax.devices())
         trainer_kwargs = dict(
             model_kwargs=dict(
                 num_layers=80,
@@ -648,7 +655,7 @@
             ),
             learner_kwargs=dict(peak_lr=1.5e-4, weight_decay=0.1),
             max_sequence_length=max_sequence_length,
-            train_batch_size=train_batch_size,
+            train_batch_size=devices,  # Global batch size: one example per device.
             max_step=max_step,
             mesh_shape=mesh_shape_from_axes(fsdp=-1),
             mesh_rules=(
@@ -914,22 +921,40 @@ def trainer_configs(
     """
     arch = "fuji"
     config_map = {}
-    for version, model_size, flash_attention in itertools.product(
-        Version, MODEL_SIZES, [True, False]
+    orbax_options = [
+        (True, False),  # use_orbax_emergency_ckpt=True, use_orbax_ckpt=False.
+        (False, True),  # use_orbax_emergency_ckpt=False, use_orbax_ckpt=True.
+        (False, False),  # Neither is used.
+    ]
+    for version, model_size, flash_attention, (use_orbax_emergency_ckpt, use_orbax_ckpt) in itertools.product(
+        Version, MODEL_SIZES, [True, False], orbax_options
     ):
         if model_size not in TOTAL_TOKENS[version]:
             # This combination does not exist.
             continue
         vocab_size = VOCAB_SIZE[version]
+
+        current_suffix_parts = []
+        if flash_attention:
+            current_suffix_parts.append("-flash")
+        if use_orbax_emergency_ckpt:
+            current_suffix_parts.append("-orbaxem")
+        elif use_orbax_ckpt:
+            current_suffix_parts.append("-orbax")
+
+        current_suffix = "".join(current_suffix_parts)
+        logging.info("Config name suffix: %s", current_suffix)
         config_name = make_config_name(
             arch=arch,
             model_size=model_size,
             version=f"v{version.value}",
-            suffix="-flash" if flash_attention else "",
+            suffix=current_suffix,
         )
         kwargs = get_trainer_kwargs(
             model_size, vocab_size=vocab_size, version=version, flash_attention=flash_attention
         )
         max_sequence_length = kwargs.pop("max_sequence_length")
+        checkpointer_str = "OrbaxEmergencyCheckpointer" if use_orbax_emergency_ckpt else ""
+        checkpointer_str = "OrbaxCheckpointer" if use_orbax_ckpt else checkpointer_str
         # pylint: disable-next=unexpected-keyword-arg,missing-kwoa
         config_map[config_name] = get_trainer_config_fn(
             train_input_source=train_input_source(
@@ -939,6 +964,7 @@
             evalers=evaler_config_dict(
                 eval_input_sources(vocab_size=vocab_size, max_sequence_length=max_sequence_length),
             ),
+            checkpointer=checkpointer_str,
             **kwargs,
         )
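
Note (not part of the patch): a minimal sketch of how the new config-name variants could be sanity-checked. It assumes the usual AXLearn pattern where the experiment module (here axlearn.experiments.text.gpt.c4_trainer) exposes named_trainer_configs(), and that the Fuji 7B v2 flash configs exist with the suffixes produced by the logic above; neither is guaranteed by this diff.

# Hypothetical smoke test under the assumptions stated above.
from axlearn.experiments.text.gpt import c4_trainer

configs = c4_trainer.named_trainer_configs()
assert "fuji-7B-v2-flash" in configs          # Pre-existing name, still registered.
assert "fuji-7B-v2-flash-orbax" in configs    # New OrbaxCheckpointer variant.
assert "fuji-7B-v2-flash-orbaxem" in configs  # New OrbaxEmergencyCheckpointer variant.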