9 changes: 9 additions & 0 deletions axlearn/common/checkpointer_orbax_emergency.py
@@ -54,6 +54,11 @@

FLAGS = flags.FLAGS

def save_axlearn_checkpoint(step: int, state, directory: str, name: str):
    """Saves `state` at `step` as a native AXLearn (TensorStore) checkpoint under `directory`."""
    cfg = Checkpointer.default_config().set(name=name, dir=directory)
    ckpt = cfg.instantiate(parent=None)
    ckpt.save(step=step, state=state)
    ckpt.wait_until_finished()

@contextmanager
def setup(spec: str):
@@ -819,6 +824,10 @@ def restore(
        )
        time_diff = time.perf_counter() - start_t
        logging.info("Took %ss to restore emergency checkpoint from %s.", time_diff, cfg.dir)

        logging.info("Saving an AXLearn tensorstore from the restored Orbax state...")
        save_axlearn_checkpoint(step, restored_state, cfg.dir, cfg.name)
Comment on lines +828 to +829 (Contributor):

If I understand correctly, this still performs the checkpoint conversion online, which means each conversion has to allocate the same resources as the training stage (or at least a slice of it).

I'm wondering if we could instead do the conversion offline on a CPU-only node with large memory.
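A minimal sketch of that offline path, reusing the `save_axlearn_checkpoint` helper added in this diff. This is only an illustration: it assumes the persistent Orbax checkpoint can be read back as a plain pytree via `orbax.checkpoint.PyTreeCheckpointer`, and `convert_offline`, the directory arguments, and the step value are hypothetical, not part of this PR.

```python
# Hypothetical offline conversion on a CPU-only, large-memory node.
import jax
import orbax.checkpoint as ocp

from axlearn.common.checkpointer_orbax_emergency import save_axlearn_checkpoint


def convert_offline(orbax_step_dir: str, output_dir: str, step: int, name: str = "offline_converter"):
    # Force CPU so the conversion does not occupy accelerator resources.
    jax.config.update("jax_platforms", "cpu")
    # Read the persistent Orbax checkpoint back into host memory as a pytree
    # (assumes the on-disk layout is readable by a plain PyTreeCheckpointer).
    restored_state = ocp.PyTreeCheckpointer().restore(orbax_step_dir)
    # Re-save it in the native AXLearn TensorStore layout via the helper above.
    save_axlearn_checkpoint(step=step, state=restored_state, directory=output_dir, name=name)
```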


        return step, restored_state

    def wait_until_finished(self):
41 changes: 35 additions & 6 deletions axlearn/experiments/text/gpt/common.py
@@ -643,6 +643,7 @@ def get_trainer_config_fn(
    keep_every_n_steps: int = 50_000,
    save_every_n_steps: Optional[int] = None,
    init_state_builder: Optional[state_builder.Builder.Config] = None,
    checkpointer: str = "",
) -> TrainerConfigFn:
    """Builds a TrainerConfigFn according to the model and input specs.

@@ -710,12 +711,40 @@ def config_fn() -> InstantiableConfig:
            )
            cfg.evalers[name] = evaler_cfg
        # Summaries and checkpoints.
        cfg.checkpointer.save_policy = config_for_function(every_n_steps_and_last_policy).set(
            n=save_every_n_steps or min(eval_every_n_steps, 5_000),
            max_step=max_step,
        )
        cfg.checkpointer.keep_every_n_steps = min(max_step, keep_every_n_steps)
        cfg.checkpointer.keep_last_n = 3
        calculated_save_every_n_steps = save_every_n_steps or min(eval_every_n_steps, 500)

        if not checkpointer:
            cfg.checkpointer.save_policy = config_for_function(every_n_steps_and_last_policy).set(
                n=calculated_save_every_n_steps,
                max_step=max_step,
            )
            cfg.checkpointer.keep_every_n_steps = min(max_step, keep_every_n_steps)
            cfg.checkpointer.keep_last_n = 3
        elif checkpointer == "OrbaxEmergencyCheckpointer":
            # Prevent global dependency on Orbax.
            # pylint: disable-next=import-outside-toplevel
            from axlearn.common.checkpointer_orbax_emergency import OrbaxEmergencyCheckpointer

            ckpt_config: OrbaxEmergencyCheckpointer.Config = (
                OrbaxEmergencyCheckpointer.default_config()
            )
            ckpt_config.save_policy = config_for_function(every_n_steps_and_last_policy).set(
                # n=calculated_save_every_n_steps,
                # Saving every 15 minutes or more is recommended for the persistent checkpoint.
                n=200,
                max_step=max_step,
            )
            ckpt_config.local_save_policy = config_for_function(every_n_steps_and_last_policy).set(
                # n=calculated_save_every_n_steps,
                # Saving every 2 minutes or more is generally recommended for the local checkpoint.
                n=30,
                max_step=max_step,
            )
            ckpt_config.local_dir = "/host-tmp/checkpoints"
            ckpt_config.keep_every_n_steps = min(max_step, keep_every_n_steps)
            ckpt_config.keep_last_n = 3
            ckpt_config.replica_axis_index = 1
            cfg.checkpointer = ckpt_config
        cfg.summary_writer.write_every_n_steps = min(eval_every_n_steps, 100)
        cfg.summary_writer.max_queue = 1000
        if len(mesh_axis_names) != len(mesh_shape):
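One note on the hard-coded save intervals in the `OrbaxEmergencyCheckpointer` branch above: `n=200` and `n=30` only match the "15 minutes" and "2 minutes" guidance for a particular step time (roughly 4-5 s per step, which is inferred here rather than stated in the diff). A quick back-of-the-envelope check:

```python
# Back-of-the-envelope check of the hard-coded cadences; the step time is an assumption.
step_time_s = 4.5  # assumed; substitute the measured step time of the actual run
print(f"persistent saves (n=200): every {200 * step_time_s / 60:.1f} min")  # ~15 min
print(f"local saves (n=30): every {30 * step_time_s / 60:.2f} min")  # ~2.25 min
```

For runs with a different step time, deriving these values from `calculated_save_every_n_steps` (or from the measured step time) would keep the cadence closer to the stated recommendation.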
25 changes: 20 additions & 5 deletions axlearn/experiments/text/gpt/fuji.py
@@ -366,6 +366,9 @@ def get_trainer_kwargs(
            ),
        )
    elif model_size == "7B":
        import jax

        # Global batch size: one sequence per accelerator device.
        gbs = len(jax.devices())
        trainer_kwargs = dict(
            model_kwargs=dict(
                num_layers=32,
@@ -378,7 +381,7 @@
            ),
            learner_kwargs=dict(peak_lr=3e-4, weight_decay=0.1),
            max_sequence_length=max_sequence_length,
            train_batch_size=train_batch_size,
            train_batch_size=gbs,
            max_step=max_step,
            mesh_shape=mesh_shape_from_axes(data=-1, fsdp=8),
            mesh_rules=(
@@ -633,6 +636,9 @@ def get_trainer_kwargs(
            ),
        )
    elif model_size == "70B":
        import jax

        # Global batch size: one sequence per accelerator device.
        devices = len(jax.devices())
        trainer_kwargs = dict(
            model_kwargs=dict(
                num_layers=80,
@@ -648,7 +654,7 @@
            ),
            learner_kwargs=dict(peak_lr=1.5e-4, weight_decay=0.1),
            max_sequence_length=max_sequence_length,
            train_batch_size=train_batch_size,
            train_batch_size=devices,
            max_step=max_step,
            mesh_shape=mesh_shape_from_axes(fsdp=-1),
            mesh_rules=(
@@ -914,22 +920,30 @@ def trainer_configs(
"""
arch = "fuji"
config_map = {}
for version, model_size, flash_attention in itertools.product(
Version, MODEL_SIZES, [True, False]
for version, model_size, flash_attention, use_orbax_emergency_ckpt in itertools.product(
Version, MODEL_SIZES, [True, False], [False, True]
):
if model_size not in TOTAL_TOKENS[version]: # This combination does not exist.
continue
vocab_size = VOCAB_SIZE[version]

current_suffix_parts = []
if flash_attention:
current_suffix_parts.append("-flash")
if use_orbax_emergency_ckpt:
current_suffix_parts.append("-orbaxem")
current_suffix = "".join(current_suffix_parts)
config_name = make_config_name(
arch=arch,
model_size=model_size,
version=f"v{version.value}",
suffix="-flash" if flash_attention else "",
suffix=current_suffix,
)
kwargs = get_trainer_kwargs(
model_size, vocab_size=vocab_size, version=version, flash_attention=flash_attention
)
max_sequence_length = kwargs.pop("max_sequence_length")
checkpointer_str = "OrbaxEmergencyCheckpointer" if use_orbax_emergency_ckpt else ""
# pylint: disable-next=unexpected-keyword-arg,missing-kwoa
config_map[config_name] = get_trainer_config_fn(
train_input_source=train_input_source(
@@ -939,6 +953,7 @@
            evalers=evaler_config_dict(
                eval_input_sources(vocab_size=vocab_size, max_sequence_length=max_sequence_length),
            ),
            checkpointer=checkpointer_str,
            **kwargs,
        )
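With the extra `use_orbax_emergency_ckpt` axis, every existing (version, size, flash-attention) combination now also gets an `-orbaxem` variant that passes `checkpointer="OrbaxEmergencyCheckpointer"` through to `get_trainer_config_fn`. The suffix logic reduces to the sketch below; the printed config-name format is an assumption about `make_config_name`, shown only for illustration.

```python
import itertools

# Sketch of the suffix grid the loop above produces for a single (version, model_size) pair.
for flash_attention, use_orbax_emergency_ckpt in itertools.product([True, False], [False, True]):
    suffix = ("-flash" if flash_attention else "") + ("-orbaxem" if use_orbax_emergency_ckpt else "")
    print(f"fuji-7B-v3{suffix}")  # e.g. fuji-7B-v3-flash-orbaxem (exact format assumed)
```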
