diff --git a/src/MaxText/checkpointing.py b/src/MaxText/checkpointing.py
index 7e9df4802..154334534 100644
--- a/src/MaxText/checkpointing.py
+++ b/src/MaxText/checkpointing.py
@@ -30,6 +30,8 @@
 import orbax.checkpoint as ocp
 from orbax.checkpoint import v1 as ocp_v1
 from orbax.checkpoint._src.arrays import sharding as sharding_utils
+from orbax.checkpoint._src.checkpoint_managers import preservation_policy as preservation_policy_lib
+from orbax.checkpoint._src.checkpoint_managers import save_decision_policy as save_decision_policy_lib
 import orbax.checkpoint.experimental.emergency.checkpoint_manager as emergency_checkpoint_manager
 import orbax.checkpoint.experimental.emergency.replicator_checkpoint_manager as emergency_replicator_checkpoint_manager
 # pylint: disable=too-many-positional-arguments
@@ -185,6 +187,8 @@ def create_orbax_checkpoint_manager(
     orbax_logger: Any = None,  # pytype: disable=attribute-error
     use_ocdbt: bool = True,
     use_zarr3: bool = True,
+    enable_continuous_checkpointing: bool = True,
+    keep_last_n_checkpoints: int = 10,
 ):
   """Returns specified Orbax (async or not) CheckpointManager or None if checkpointing is disabled."""
   if not enable_checkpointing:
@@ -205,6 +209,15 @@ def create_orbax_checkpoint_manager(
   # local storage checkpoint needs parent directory created
   p = epath.Path(checkpoint_dir)
   p.mkdir(exist_ok=True, parents=True)
+  if enable_continuous_checkpointing:
+    save_decision_policy = save_decision_policy_lib.ContinuousCheckpointingPolicy()
+    preservation_policy = preservation_policy_lib.AnyPreservationPolicy(
+        [preservation_policy_lib.LatestN(keep_last_n_checkpoints),
+         preservation_policy_lib.EveryNSteps(save_interval_steps)]
+    )
+  else:
+    save_decision_policy = None
+    preservation_policy = None
   manager = CheckpointManager(
       p,
       item_names=item_names,
@@ -213,6 +226,8 @@ def create_orbax_checkpoint_manager(
           create=True,
           save_interval_steps=save_interval_steps,
           enable_async_checkpointing=use_async,
+          save_decision_policy=save_decision_policy,
+          preservation_policy=preservation_policy,
       ),
       logger=orbax_logger,
   )
diff --git a/src/MaxText/configs/base.yml b/src/MaxText/configs/base.yml
index de05461d9..f68ccd1a3 100644
--- a/src/MaxText/configs/base.yml
+++ b/src/MaxText/configs/base.yml
@@ -51,6 +51,8 @@
 save_checkpoint_on_completion: True
 async_checkpointing: True
 checkpoint_period: 10_000
 max_num_checkpoints_to_keep: None
+enable_continuous_checkpointing: True
+keep_last_n_checkpoints: 10
 # enables one replica to read the ckpt then broadcast to the rest
 enable_single_replica_ckpt_restoring: False
diff --git a/src/MaxText/train_utils.py b/src/MaxText/train_utils.py
index ce465a56e..29e16a4ff 100644
--- a/src/MaxText/train_utils.py
+++ b/src/MaxText/train_utils.py
@@ -73,6 +73,8 @@ def create_training_tools(config, model, mesh):
       logger,
       use_ocdbt,
       use_zarr3,
+      config.enable_continuous_checkpointing,
+      config.keep_last_n_checkpoints,
   )
   return init_rng, checkpoint_manager, learning_rate_schedule, tx
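
For reviewers, a minimal standalone sketch (not part of the patch, and not Orbax or MaxText API) of the retention behavior the new policy combination is intended to express, assuming `AnyPreservationPolicy` keeps a checkpoint whenever any of its member policies would; `should_preserve` and the example numbers below are purely illustrative.

```python
# Illustrative only: mirrors the intent of
# AnyPreservationPolicy([LatestN(keep_last_n_checkpoints),
#                        EveryNSteps(save_interval_steps)]),
# assuming the combined policy keeps a step when ANY member policy keeps it.
def should_preserve(step: int, all_steps: list[int], keep_last_n: int, every_n_steps: int) -> bool:
  """Returns True if `step` would be retained under the combined policy."""
  is_among_latest_n = step in sorted(all_steps)[-keep_last_n:]
  is_interval_step = every_n_steps > 0 and step % every_n_steps == 0
  return is_among_latest_n or is_interval_step


# Example: continuous checkpointing has produced a save every 500 steps,
# with checkpoint_period=10_000 and keep_last_n_checkpoints=10.
steps = list(range(0, 50_500, 500))
kept = [s for s in steps if should_preserve(s, steps, keep_last_n=10, every_n_steps=10_000)]
# `kept` contains every 10_000th step plus the 10 most recent 500-step saves.
```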