10 changes: 5 additions & 5 deletions axlearn/cloud/gcp/jobset_utils.py
@@ -508,10 +508,10 @@ def _build_uploader_container(
dst = f"{cfg.output_dir}/output/$HOSTNAME/"
interval_s = 60
sync_command = f"while true; do gsutil -m rsync -r {src} {dst}; sleep {interval_s}; done"
resources = {
"requests": {"cpu": "100m", "memory": "128Mi"},
"limits": {"cpu": "500m", "memory": "256Mi"},
}
# resources = {
# "requests": {"cpu": "100m", "memory": "128Mi"},
# "limits": {"cpu": "500m", "memory": "256Mi"},
# }
return dict(
name="output-uploader",
image="google/cloud-sdk:alpine",
@@ -520,7 +520,7 @@ def _build_uploader_container(
restartPolicy="Always",
command=["/bin/sh", "-c"],
args=[sync_command],
resources=resources,
# resources=resources,
volumeMounts=[output_volume_mount],
)

9 changes: 9 additions & 0 deletions axlearn/common/checkpointer_orbax_emergency.py
@@ -54,6 +54,11 @@

FLAGS = flags.FLAGS

def save_axlearn_checkpoint(step: int, state, directory: str, name: str):
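"""Saves the restored in-memory state as a standard (non-Orbax) AXLearn checkpoint.

Blocks until the checkpoint write has finished.
"""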
cfg = Checkpointer.default_config().set(name=name, dir=directory)
ckpt = cfg.instantiate(parent=None)
ckpt.save(step=step, state=state)
ckpt.wait_until_finished()

@contextmanager
def setup(spec: str):
@@ -819,6 +824,10 @@ def restore(
)
time_diff = time.perf_counter() - start_t
logging.info("Took %ss to restore emergency checkpoint from %s.", time_diff, cfg.dir)

logging.info("Saving a non-Orbax checkpoint from the restored Orbax state...")
save_axlearn_checkpoint(step, restored_state, cfg.dir, cfg.name)

return step, restored_state

def wait_until_finished(self):
41 changes: 35 additions & 6 deletions axlearn/experiments/text/gpt/common.py
@@ -643,6 +643,7 @@ def get_trainer_config_fn(
keep_every_n_steps: int = 50_000,
save_every_n_steps: Optional[int] = None,
init_state_builder: Optional[state_builder.Builder.Config] = None,
checkpointer: str = "",  # "" for the default checkpointer, or "OrbaxEmergencyCheckpointer".
) -> TrainerConfigFn:
"""Builds a TrainerConfigFn according to the model and input specs.

@@ -710,12 +711,40 @@ def config_fn() -> InstantiableConfig:
)
cfg.evalers[name] = evaler_cfg
# Summaries and checkpoints.
cfg.checkpointer.save_policy = config_for_function(every_n_steps_and_last_policy).set(
n=save_every_n_steps or min(eval_every_n_steps, 5_000),
max_step=max_step,
)
cfg.checkpointer.keep_every_n_steps = min(max_step, keep_every_n_steps)
cfg.checkpointer.keep_last_n = 3
calculated_save_every_n_steps = save_every_n_steps or min(eval_every_n_steps, 500)

if not checkpointer:
cfg.checkpointer.save_policy = config_for_function(every_n_steps_and_last_policy).set(
n=calculated_save_every_n_steps,
max_step=max_step,
)
cfg.checkpointer.keep_every_n_steps = min(max_step, keep_every_n_steps)
cfg.checkpointer.keep_last_n = 3
elif checkpointer == "OrbaxEmergencyCheckpointer":
# Prevent global dependency on Orbax.
# pylint: disable-next=import-outside-toplevel
from axlearn.common.checkpointer_orbax_emergency import OrbaxEmergencyCheckpointer

ckpt_config: OrbaxEmergencyCheckpointer.Config = (
OrbaxEmergencyCheckpointer.default_config()
)
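# Policy for persistent checkpoints written to the checkpointer's dir (typically GCS).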
ckpt_config.save_policy = config_for_function(every_n_steps_and_last_policy).set(
# n=calculated_save_every_n_steps,
# Saving every 15 minutes or more is recommended.
n=200,
max_step=max_step,
)
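# Policy for frequent local checkpoints kept on node-local storage for fast in-job restarts.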
ckpt_config.local_save_policy = config_for_function(every_n_steps_and_last_policy).set(
# n=calculated_save_every_n_steps,
# Saving every 2 minutes or more is generally recommended.
n=30,
max_step=max_step,
)
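# Node-local path for local checkpoints; /host-tmp is assumed to be a host-mounted volume on each worker.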
ckpt_config.local_dir = "/host-tmp/checkpoints"
ckpt_config.keep_every_n_steps = min(max_step, keep_every_n_steps)
ckpt_config.keep_last_n = 3
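# Assumes mesh axis 1 (the "data" axis) indexes the data-parallel replicas used for checkpoint recovery.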
ckpt_config.replica_axis_index = 1
cfg.checkpointer = ckpt_config
cfg.summary_writer.write_every_n_steps = min(eval_every_n_steps, 100)
cfg.summary_writer.max_queue = 1000
if len(mesh_axis_names) != len(mesh_shape):
25 changes: 20 additions & 5 deletions axlearn/experiments/text/gpt/fuji.py
@@ -366,6 +366,9 @@ def get_trainer_kwargs(
),
)
elif model_size == "7B":
import jax

# Global batch size: one training example per device.
gbs = len(jax.devices())
trainer_kwargs = dict(
model_kwargs=dict(
num_layers=32,
@@ -378,7 +381,7 @@
),
learner_kwargs=dict(peak_lr=3e-4, weight_decay=0.1),
max_sequence_length=max_sequence_length,
train_batch_size=train_batch_size,
train_batch_size=gbs,
max_step=max_step,
mesh_shape=mesh_shape_from_axes(data=-1, fsdp=8),
mesh_rules=(
@@ -633,6 +636,9 @@ def get_trainer_kwargs(
),
)
elif model_size == "70B":
import jax

# Global batch size scales with the device count (one example per device).
devices = len(jax.devices())
trainer_kwargs = dict(
model_kwargs=dict(
num_layers=80,
@@ -648,7 +654,7 @@
),
learner_kwargs=dict(peak_lr=1.5e-4, weight_decay=0.1),
max_sequence_length=max_sequence_length,
train_batch_size=train_batch_size,
train_batch_size=devices * 1,
max_step=max_step,
mesh_shape=mesh_shape_from_axes(fsdp=-1),
mesh_rules=(
@@ -914,22 +920,30 @@ def trainer_configs(
"""
arch = "fuji"
config_map = {}
for version, model_size, flash_attention in itertools.product(
Version, MODEL_SIZES, [True, False]
for version, model_size, flash_attention, use_orbax_emergency_ckpt in itertools.product(
Version, MODEL_SIZES, [True, False], [False, True]
):
if model_size not in TOTAL_TOKENS[version]: # This combination does not exist.
continue
vocab_size = VOCAB_SIZE[version]

current_suffix_parts = []
if flash_attention:
current_suffix_parts.append("-flash")
if use_orbax_emergency_ckpt:
current_suffix_parts.append("-orbaxem")
current_suffix = "".join(current_suffix_parts)
config_name = make_config_name(
arch=arch,
model_size=model_size,
version=f"v{version.value}",
suffix="-flash" if flash_attention else "",
suffix=current_suffix,
)
kwargs = get_trainer_kwargs(
model_size, vocab_size=vocab_size, version=version, flash_attention=flash_attention
)
max_sequence_length = kwargs.pop("max_sequence_length")
checkpointer_str = "OrbaxEmergencyCheckpointer" if use_orbax_emergency_ckpt else ""
# pylint: disable-next=unexpected-keyword-arg,missing-kwoa
config_map[config_name] = get_trainer_config_fn(
train_input_source=train_input_source(
@@ -939,6 +953,7 @@
evalers=evaler_config_dict(
eval_input_sources(vocab_size=vocab_size, max_sequence_length=max_sequence_length),
),
checkpointer=checkpointer_str,
**kwargs,
)
