Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions src/MaxText/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
import orbax.checkpoint as ocp
from orbax.checkpoint import v1 as ocp_v1
from orbax.checkpoint._src.arrays import sharding as sharding_utils
from orbax.checkpoint._src.checkpoint_managers import preservation_policy as preservation_policy_lib
from orbax.checkpoint._src.checkpoint_managers import save_decision_policy as save_decision_policy_lib
import orbax.checkpoint.experimental.emergency.checkpoint_manager as emergency_checkpoint_manager
import orbax.checkpoint.experimental.emergency.replicator_checkpoint_manager as emergency_replicator_checkpoint_manager
# pylint: disable=too-many-positional-arguments
Expand Down Expand Up @@ -185,6 +187,8 @@ def create_orbax_checkpoint_manager(
orbax_logger: Any = None, # pytype: disable=attribute-error
use_ocdbt: bool = True,
use_zarr3: bool = True,
enable_continuous_checkpointing: bool = True,
keep_last_n_checkpoints: int = 10,
):
"""Returns specified Orbax (async or not) CheckpointManager or None if checkpointing is disabled."""
if not enable_checkpointing:
Expand All @@ -205,6 +209,15 @@ def create_orbax_checkpoint_manager(
# local storage checkpoint needs parent directory created
p = epath.Path(checkpoint_dir)
p.mkdir(exist_ok=True, parents=True)
if enable_continuous_checkpointing:
save_decision_policy = save_decision_policy_lib.ContinuousCheckpointingPolicy()
preservation_policy = preservation_policy_lib.AnyPreservationPolicy(
[preservation_policy_lib.LatestN(keep_last_n_checkpoints),
preservation_policy_lib.EveryNSteps(save_interval_steps)]
)
else:
save_decision_policy = None
preservation_policy = None
manager = CheckpointManager(
p,
item_names=item_names,
Expand All @@ -213,6 +226,9 @@ def create_orbax_checkpoint_manager(
create=True,
save_interval_steps=save_interval_steps,
enable_async_checkpointing=use_async,
save_decision_policy=save_decision_policy,
preservation_policy=preservation_policy,
),
),
logger=orbax_logger,
)
Expand Down
2 changes: 2 additions & 0 deletions src/MaxText/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ save_checkpoint_on_completion: True
async_checkpointing: True
checkpoint_period: 10_000
max_num_checkpoints_to_keep: None
enable_continuous_checkpointing: True
keep_last_n_checkpoints: 10
# enables one replica to read the ckpt then broadcast to the rest
enable_single_replica_ckpt_restoring: False

Expand Down
2 changes: 2 additions & 0 deletions src/MaxText/train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ def create_training_tools(config, model, mesh):
logger,
use_ocdbt,
use_zarr3,
config.enable_continuous_checkpointing,
config.keep_last_n_checkpoints,
)

return init_rng, checkpoint_manager, learning_rate_schedule, tx
Expand Down
Loading