Draft
Changes from all commits
58 commits
6005b6d
add test script
samos123 Jun 20, 2025
40d1ed5
Jun Orbax regular checkpointer fixes
samos123 Jun 20, 2025
b9194c1
Orbax emergency trainer config for Fuji
samos123 May 22, 2025
74f7522
update fuji configs to use regular orbax checkpointer
samos123 Jun 20, 2025
8e80931
support keep_period
samos123 Jun 20, 2025
10d0576
support run for orbax regular checkpointer
samos123 Jun 20, 2025
abbe0ac
fix for A TypeHandler for "<class 'jax.Array'>" is already registered.
samos123 Jun 20, 2025
1292819
pdbs=1 and print every step
samos123 Jun 20, 2025
073cbfa
checkpoint every 100 steps
samos123 Jun 20, 2025
d723cc7
increase termination Grace Period to 300s
samos123 Jun 20, 2025
383d0bc
termination grace period to 900s
samos123 Jun 20, 2025
90f1e27
Revert "termination grace period to 900s"
samos123 Jun 20, 2025
871d2db
save the data iterator
samos123 Jun 25, 2025
bd25cff
use fuji v3 8b for tiktoken
samos123 Jul 21, 2025
8923ca6
use tokenizers instead of tokenizer
samos123 Jul 28, 2025
03c3fb6
disable saving of data iterator
samos123 Jul 28, 2025
410134d
use orbaxem
samos123 Jul 29, 2025
9051f90
add needed fix for orbax em
samos123 Jul 29, 2025
b10878f
enable orbax debug logging
samos123 Jul 29, 2025
bc3cbfc
sort the to be assigned keys and available process indexes
samos123 Jul 30, 2025
f8000d7
sort proc_infos as well
samos123 Jul 30, 2025
e752a12
Revert "sort proc_infos as well"
samos123 Jul 30, 2025
c7bdd39
Revert "sort the to be assigned keys and available process indexes"
samos123 Jul 30, 2025
13cb8ac
gemini fix?
samos123 Jul 30, 2025
ce3d7ad
fail fast
findmyway Jul 30, 2025
6223a3f
Revert "gemini fix?"
samos123 Jul 30, 2025
879afcb
Merge branch 'orbax-fuji-v2' of github.com:samos123/axlearn into orba…
samos123 Jul 30, 2025
1fcb62b
switch orbax with Jun's patch
samos123 Jul 30, 2025
ad4a498
add large scale config
samos123 Jul 31, 2025
3be33cd
enable BlockingRecreate
samos123 Jul 31, 2025
c4f20f2
disable BlockingRecreate on jobset
samos123 Jul 31, 2025
0d9513c
switch cluster
samos123 Aug 3, 2025
c4f8766
remove debug logging
samos123 Aug 3, 2025
0870285
add script to force delete pods
samos123 Aug 3, 2025
0a079a1
use BlockingRecreate
samos123 Aug 3, 2025
227c94b
add goodput recorder
samos123 Aug 4, 2025
0c2ce00
Integrate AXLearn with latest Goodput package
dipannita08 Jul 25, 2025
799c4a9
switch command to new goodput library
samos123 Aug 4, 2025
a138a9f
print log every step
samos123 Aug 4, 2025
5cc91ee
force deletion correctly for terminating pods
samos123 Aug 4, 2025
ecbfa95
switch to debug cluster gcs bucket
samos123 Aug 5, 2025
eab4ed5
bump orbax em fork
samos123 Aug 5, 2025
9318bcd
turn off goodput logging since it needs more permissions
samos123 Aug 5, 2025
2b11540
Revert "Integrate AXLearn with latest Goodput package"
samos123 Aug 5, 2025
f888f90
print jax_devices
samos123 Aug 6, 2025
8ef0220
fsdp=256 data=-1 so ici_dp=1
samos123 Aug 6, 2025
2c2e134
use latest main of orbax
samos123 Aug 7, 2025
8b3a425
70b fsdp=64,data=-1
samos123 Aug 7, 2025
ea6fc62
fsdp=32 with 70b
samos123 Aug 7, 2025
5f70377
70b fsdp=64
samos123 Aug 7, 2025
e33e04a
bump orbax to 0.11.21
samos123 Aug 7, 2025
dd14462
7b fsdp=16
samos123 Aug 7, 2025
5b1401a
use jun's fix for fsdp=16 data=16
samos123 Aug 8, 2025
b3f67b9
fsdp=256 7b
samos123 Aug 8, 2025
09df9e9
use orbax em with single replica GCS restore
samos123 Aug 8, 2025
edd48ee
70b fsdp=64
samos123 Aug 8, 2025
ff64d65
try with new orbax fix
samos123 Aug 9, 2025
056e0ea
update orbax
samos123 Aug 16, 2025
2 changes: 1 addition & 1 deletion Dockerfile
@@ -83,7 +83,7 @@ ENTRYPOINT ["/opt/apache/beam/boot"]

FROM base AS tpu

ARG EXTRAS=
ARG EXTRAS=orbax

ENV UV_FIND_LINKS=https://storage.googleapis.com/jax-releases/libtpu_releases.html
# Ensure we install the TPU version, even if building locally.
6 changes: 5 additions & 1 deletion axlearn/cloud/gcp/job.py
@@ -134,7 +134,11 @@ def _build_jobset(self) -> Nested[Any]:
return dict(
metadata=dict(name=cfg.name, annotations=annotations),
spec=dict(
failurePolicy=dict(maxRestarts=cfg.max_tries - 1),
failurePolicy=dict(
maxRestarts=cfg.max_tries - 1,
restartStrategy="BlockingRecreate"
# maxRestarts=cfg.max_tries - 1,
),
replicatedJobs=self._builder(),
),
)
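Note on the failurePolicy change above: "BlockingRecreate" is the JobSet restart strategy that, per the JobSet API, tears down all child Jobs before recreating them on a restart instead of recreating them in place; treat that description as an assumption, it is not verified here. A rough sketch of the fragment _build_jobset now emits, with cfg.max_tries stubbed as a plain value:

```python
# Sketch only (not the PR code): the failurePolicy fragment produced by _build_jobset,
# assuming cfg.max_tries == 3. The BlockingRecreate semantics described above are an
# assumption based on the JobSet API.
import json

max_tries = 3  # stand-in for cfg.max_tries
failure_policy = dict(
    maxRestarts=max_tries - 1,
    restartStrategy="BlockingRecreate",
)
print(json.dumps(dict(failurePolicy=failure_policy), indent=2))
```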
5 changes: 4 additions & 1 deletion axlearn/cloud/gcp/jobset_utils.py
@@ -451,6 +451,9 @@ def _build_container(self) -> Nested[Any]:
if cfg.enable_tpu_ici_resiliency is not None:
env_vars["ENABLE_ICI_RESILIENCY"] = str(cfg.enable_tpu_ici_resiliency).lower()

env_vars["TPU_PREMAPPED_BUFFER_SIZE"] = "137438953472"
env_vars["TPU_PREMAPPED_BUFFER_TRANSFER_THRESHOLD_BYTES"] = "137438953472"

resources = {"limits": {"google.com/tpu": system.chips_per_vm}}
# Set request memory by host machine type.
machine_memory_gi = GCE_MACHINE_TYPE_TO_MEMORY_CHARACTERISTICS.get(
@@ -690,7 +693,7 @@ def _build_pod(self) -> Nested[Any]:

spec = dict(
# NOTE: Don't set hostNetwork or dnsPolicy for compat with Workload Identity.
terminationGracePeriodSeconds=60,
terminationGracePeriodSeconds=300,
# Fail if any pod fails, and allow retries to happen at JobSet level.
restartPolicy="Never",
# https://kubernetes.io/docs/tasks/network/customize-hosts-file-for-pods/#adding-additional-entries-with-hostaliases
68 changes: 65 additions & 3 deletions axlearn/common/checkpointer_orbax.py
@@ -13,9 +13,12 @@
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import jax
import numpy as np
import orbax.checkpoint as ocp
import tensorflow as tf
from absl import logging
from orbax.checkpoint._src.metadata import array_metadata_store as array_metadata_store_lib
from orbax.checkpoint._src.serialization.type_handlers import ArrayHandler

from axlearn.common import utils
from axlearn.common.checkpointer import (
@@ -187,15 +190,18 @@ class Config(BaseCheckpointer.Config):

Attributes:
keep_last_n: Keep this many past ckpts.
keep_every_n_steps: If set, keep a checkpoint every n steps.
validation_type: Checkpoint validation during restore.
async_timeout_secs: Timeout for async barrier in seconds.
"""

keep_last_n: int = 1
keep_every_n_steps: Optional[int] = None
validation_type: CheckpointValidationType = CheckpointValidationType.EXACT
async_timeout_secs: int = 300
max_concurrent_save_gb: Optional[int] = None
max_concurrent_restore_gb: Optional[int] = None
enable_single_replica_ckpt_restoring: bool = True

@classmethod
def checkpoint_paths(cls, base_dir: str) -> List[str]:
@@ -237,6 +243,7 @@ def save_fn_with_summaries(step: int, last_saved_step: Optional[int]) -> bool:
options=ocp.CheckpointManagerOptions(
create=True,
max_to_keep=cfg.keep_last_n,
keep_period=cfg.keep_every_n_steps,
enable_async_checkpointing=True,
step_name_format=self._name_format,
should_save_fn=save_fn_with_summaries,
@@ -321,11 +328,33 @@ def restore(

cfg: OrbaxCheckpointer.Config = self.config

if cfg.enable_single_replica_ckpt_restoring:
array_handler = ocp.type_handlers.SingleReplicaArrayHandler(
replica_axis_index=0,
broadcast_memory_limit_bytes=1024 * 1024 * 1000, # 1000 MB limit
)
ocp.type_handlers.register_type_handler(jax.Array, array_handler, override=True)

def _restore_args(x: Any) -> ocp.RestoreArgs:
if isinstance(x, (Tensor, TensorSpec)):
return ocp.checkpoint_utils.construct_restore_args(
jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding)
)
if cfg.enable_single_replica_ckpt_restoring:
pspec = x.sharding.spec
mesh = x.sharding.mesh
replica_axis_index = 0
replica_devices = _replica_devices(mesh.devices, replica_axis_index)
replica_mesh = jax.sharding.Mesh(replica_devices, mesh.axis_names)
single_replica_sharding = jax.sharding.NamedSharding(replica_mesh, pspec)

return ocp.type_handlers.SingleReplicaArrayRestoreArgs(
sharding=jax.sharding.NamedSharding(mesh, pspec),
single_replica_sharding=single_replica_sharding,
global_shape=x.shape,
dtype=x.dtype,
)
else:
return ocp.checkpoint_utils.construct_restore_args(
jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding)
)
elif isinstance(x, tf.data.Iterator):
return _TfIteratorHandler.RestoreArgs(item=x)
elif _GRAIN_INSTALLED and isinstance(x, _GrainIterator):
@@ -349,6 +378,13 @@ def _restore_args(x: Any) -> ocp.RestoreArgs:
raise ValueError(f"Failed to restore at step {step}.") from e
logging.info("Could not find any completed checkpoints under %s: %s", cfg.dir, e)
return None, state # Return the input state.
finally:
if cfg.enable_single_replica_ckpt_restoring:
ocp.type_handlers.register_type_handler(
jax.Array,
ArrayHandler(array_metadata_store=array_metadata_store_lib.Store()),
override=True,
)

restored_index = composite_state["index"]
restored_state = composite_state["state"]
@@ -375,3 +411,29 @@ def wait_until_finished(self):
def stop(self, *, has_exception: bool = False):
"""See `BaseCheckpointer.stop` for details."""
self._manager.close()


def _find_idx(array: np.ndarray, replica_axis_idx: int):
"""Returns the index along given dimension that the current host belongs to."""
idx = None
for idx, val in np.ndenumerate(array):
if val.process_index == jax.process_index():
break
return idx[replica_axis_idx]


def _replica_devices(device_array: np.ndarray, replica_axis_idx: int):
"""Returns the devices from the replica that current host belongs to.

Replicas are assumed to be restricted to the first axis.

Args:
device_array: devices of the mesh that can be obtained by mesh.devices()
replica_axis_idx: axis dimension along which replica is taken

Returns:
devices inside the replica that current host is in
"""
idx = _find_idx(device_array, replica_axis_idx)
replica_result = np.take(device_array, idx, axis=replica_axis_idx)
return np.expand_dims(replica_result, axis=replica_axis_idx)
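Illustration of the two helpers above: _find_idx walks the device mesh to find where the current host sits along the replica axis, and _replica_devices slices out that single replica while keeping the replica axis with size 1. A standalone NumPy sketch with stand-in devices, since the real helpers need jax.process_index() and an actual mesh (names and values below are hypothetical):

```python
# Standalone sketch (not the PR code): mimic _find_idx/_replica_devices on a fake 2x4
# device mesh whose axis 0 is the replica axis. Each fake device only carries the
# process_index of the host that owns it; `my_process_index` replaces jax.process_index().
from types import SimpleNamespace

import numpy as np

fake_devices = np.array(
    [[SimpleNamespace(process_index=p) for _ in range(4)] for p in range(2)]
)

def find_idx(array: np.ndarray, replica_axis_idx: int, my_process_index: int) -> int:
    # Same traversal as _find_idx above, parameterized on the current process index.
    idx = None
    for idx, val in np.ndenumerate(array):
        if val.process_index == my_process_index:
            break
    return idx[replica_axis_idx]

idx = find_idx(fake_devices, replica_axis_idx=0, my_process_index=1)
replica = np.expand_dims(np.take(fake_devices, idx, axis=0), axis=0)
print(replica.shape)  # (1, 4): only the replica that "host 1" belongs to.
```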
5 changes: 5 additions & 0 deletions axlearn/common/checkpointer_orbax_emergency.py
@@ -90,6 +90,9 @@ def setup(spec: str):
FLAGS.process_id = info.inv_proc_id
FLAGS.distributed_coordinator = info.address
FLAGS.experimental_orbax_use_distributed_process_id = True
# Required for case when slices swap and ici_dp=2 or higher.
# PR that introduced this flag: https://github.com/google/orbax/pull/2222
FLAGS.experimental_use_distributed_id_for_mesh_consistency = False
yield


@@ -314,6 +317,8 @@ def _init_consistent_proc_ids(
# Then, rank 0 assigns inv_proc_id for worker that's missing their inv_proc_id and find the
# coordinator address.
if local_proc_info.cur_proc_id == 0:
jax_devices = [dev.id for dev in jax.devices()]
logging.info("jax_devices=%s", jax_devices)
ids = client.key_value_dir_get(key_prefix)
proc_infos: list[_ProcessInfo] = []

7 changes: 6 additions & 1 deletion axlearn/common/compiler_options.py
@@ -58,7 +58,12 @@ def default_xla_options(
# cause the step time to be double. You should increase this
# further if you see "Allocator failed to allocate". A feature
# to dynamically allocate may come later: b/380514965
megascale_grpc_premap_memory_bytes=17179869184,
# Needed for orbax emergency checkpointer
# Needed for decent perf
megascale_grpc_premap_memory_bytes=137438953472,
# needed for restore consistent hash
megascale_jax_offset_launch_id_by_module_name=True,
megascale_jax_use_device_set_based_launch_id=False,
# Flag controlling the maximum number of overlapping host offloadings.
xla_tpu_host_transfer_overlap_limit=24,
# Flag controlling the maximum number of overlapping cross-DCN send/recv.
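Note on the value change above: megascale_grpc_premap_memory_bytes goes from 17179869184 (16 GiB) to 137438953472 (128 GiB), the same size as the TPU_PREMAPPED_BUFFER_SIZE environment variable added in jobset_utils.py. Quick arithmetic check (sketch only; the rationale is the one given in the diff comments):

```python
# Arithmetic only, not part of the PR.
assert 17179869184 == 16 * 1024**3    # previous value: 16 GiB
assert 137438953472 == 128 * 1024**3  # new value: 128 GiB, matches TPU_PREMAPPED_BUFFER_SIZE
```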
70 changes: 34 additions & 36 deletions axlearn/common/trainer.py
@@ -613,7 +613,7 @@ def run(
)
self.vlog(3, "Done step %s", self.step)
num_steps += 1
if num_steps % 100 == 0:
if num_steps % 1 == 0:
now = time.perf_counter()
average_step_time = (now - start_time) / num_steps
self._step_log("Average step time: %s seconds", average_step_time)
@@ -927,42 +927,40 @@ def restore_checkpoint(self, restore_step: Optional[int] = None) -> Optional[int
**ckpt_state_spec, input_iter=iter(self.input.dataset())
)
restore_input_iter = cfg.save_input_iterator
try:
# Try to restore with `input_iter`.
step, ckpt_state = self.checkpointer.restore(
step=restore_step,
state=(
ckpt_state_spec_with_input_iter if restore_input_iter else ckpt_state_spec
),
)
if step is not None:
self.vlog(
0,
"Restored checkpoint at %s with restore_input_iter=%s",
step,
restore_input_iter,
)
except ValueError as e:
logging.warning(
"Attempt to restore checkpoint with restore_input_iter=%s failed: %s",
# try:
# Try to restore with `input_iter`.
step, ckpt_state = self.checkpointer.restore(
step=restore_step,
state=(ckpt_state_spec_with_input_iter if restore_input_iter else ckpt_state_spec),
)
if step is not None:
self.vlog(
0,
"Restored checkpoint at %s with restore_input_iter=%s",
step,
restore_input_iter,
e,
)
# Restore with a different restore_input_iter setting.
restore_input_iter = not restore_input_iter
step, ckpt_state = self.checkpointer.restore(
step=restore_step,
state=(
ckpt_state_spec_with_input_iter if restore_input_iter else ckpt_state_spec
),
)
if step is not None:
self.vlog(
0,
"Restored checkpoint at %s with restore_input_iter=%s",
step,
restore_input_iter,
)
# except ValueError as e:
# logging.warning(
# "Attempt to restore checkpoint with restore_input_iter=%s failed: %s",
# restore_input_iter,
# e,
# )
# # Restore with a different restore_input_iter setting.
# restore_input_iter = not restore_input_iter
# step, ckpt_state = self.checkpointer.restore(
# step=restore_step,
# state=(
# ckpt_state_spec_with_input_iter if restore_input_iter else ckpt_state_spec
# ),
# )
# if step is not None:
# self.vlog(
# 0,
# "Restored checkpoint at %s with restore_input_iter=%s",
# step,
# restore_input_iter,
# )
if step is not None:
self._step = step
self._trainer_state = TrainerState(
@@ -1097,7 +1095,7 @@ def _run_step(
# Run the compiled function.
self._trainer_state, outputs = compiled_train_step_fn(self.trainer_state, input_batch)

if self.step % 100 == 0 or 0 <= self.step <= 5:
if self.step % 1 == 0 or 0 <= self.step <= 5:
self._step_log(
"loss=%s aux=%s",
outputs["loss"],
56 changes: 50 additions & 6 deletions axlearn/experiments/text/gpt/common.py
@@ -643,6 +643,7 @@ def get_trainer_config_fn(
keep_every_n_steps: int = 50_000,
save_every_n_steps: Optional[int] = None,
init_state_builder: Optional[state_builder.Builder.Config] = None,
checkpointer: str = "",
) -> TrainerConfigFn:
"""Builds a TrainerConfigFn according to the model and input specs.

@@ -709,13 +710,56 @@ def config_fn() -> InstantiableConfig:
)
)
cfg.evalers[name] = evaler_cfg

# Summaries and checkpoints.
cfg.checkpointer.save_policy = config_for_function(every_n_steps_and_last_policy).set(
n=save_every_n_steps or min(eval_every_n_steps, 5_000),
max_step=max_step,
)
cfg.checkpointer.keep_every_n_steps = min(max_step, keep_every_n_steps)
cfg.checkpointer.keep_last_n = 3
calculated_save_every_n_steps = save_every_n_steps or min(eval_every_n_steps, 100)

if not checkpointer:
cfg.checkpointer.save_policy = config_for_function(every_n_steps_and_last_policy).set(
n=calculated_save_every_n_steps,
max_step=max_step,
)
cfg.checkpointer.keep_every_n_steps = min(max_step, keep_every_n_steps)
cfg.checkpointer.keep_last_n = 3
elif checkpointer == "OrbaxEmergencyCheckpointer":
# Prevent global dependency on Orbax.
# pylint: disable-next=import-outside-toplevel
from axlearn.common.checkpointer_orbax_emergency import OrbaxEmergencyCheckpointer

ckpt_config: OrbaxEmergencyCheckpointer.Config = (
OrbaxEmergencyCheckpointer.default_config()
)
ckpt_config.save_policy = config_for_function(every_n_steps_and_last_policy).set(
n=calculated_save_every_n_steps,
max_step=max_step,
)
ckpt_config.local_save_policy = config_for_function(every_n_steps_and_last_policy).set(
n=20,
max_step=max_step,
)
ckpt_config.local_dir = "/host-tmp/checkpoints"
ckpt_config.keep_every_n_steps = min(max_step, keep_every_n_steps)
ckpt_config.keep_last_n = 3
ckpt_config.replica_axis_index = 1
cfg.checkpointer = ckpt_config
elif checkpointer == "OrbaxRegularCheckpointer":
# Prevent global dependency on Orbax.
# pylint: disable-next=import-outside-toplevel
from axlearn.common.checkpointer_orbax import OrbaxCheckpointer

ckpt_config: OrbaxCheckpointer.Config = OrbaxCheckpointer.default_config()
ckpt_config.save_policy = config_for_function(every_n_steps_and_last_policy).set(
n=calculated_save_every_n_steps,
max_step=max_step,
)
ckpt_config.keep_every_n_steps = min(max_step, keep_every_n_steps)
ckpt_config.keep_last_n = 3
cfg.checkpointer = ckpt_config

# Save the data iterator as part of the checkpointing process.
# default is false.
# cfg.save_input_iterator = True

cfg.summary_writer.write_every_n_steps = min(eval_every_n_steps, 100)
cfg.summary_writer.max_queue = 1000
if len(mesh_axis_names) != len(mesh_shape):
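Note on the new `checkpointer` argument above: an empty string keeps the existing AXLearn checkpointer, "OrbaxEmergencyCheckpointer" adds a fast local save cadence (every 20 steps under /host-tmp/checkpoints) alongside the persistent one, and "OrbaxRegularCheckpointer" uses plain Orbax with only the persistent cadence. A small sketch of how the two cadences in the emergency branch interleave, assuming every_n_steps_and_last_policy fires on multiples of n and at max_step (numbers are hypothetical):

```python
# Illustration only (not axlearn code). Mirrors the OrbaxEmergencyCheckpointer branch with
# save_every_n_steps unset and eval_every_n_steps >= 100: persistent saves every 100 steps,
# local saves every 20 steps, both also firing at max_step.
persistent_every, local_every, max_step = 100, 20, 250

for step in range(1, max_step + 1):
    persistent = step % persistent_every == 0 or step == max_step
    local = step % local_every == 0 or step == max_step
    if persistent or local:
        kinds = "+".join(k for k, hit in (("persistent", persistent), ("local", local)) if hit)
        print(f"step {step}: {kinds} save")
```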