
Commit bf1d89e

Adding a 70B config for v5e-32, but it hits an HBM OOM

1 parent: 73e38e4

File tree

1 file changed: +19 −0 lines

  • axlearn/experiments/text/gpt


axlearn/experiments/text/gpt/fuji.py

Lines changed: 19 additions & 0 deletions
@@ -673,6 +673,25 @@ def get_trainer_kwargs(
         mesh_rules=(
             # TPU V5e maximum per device batch is 1.
             # with all activation offloading, HBM usage: 14.6GB/chip.
+            # tpu-v5e-32-1
+            (
+                "tpu-v5litepod-32-1",
+                ChainConfigModifier.default_config().set(
+                    config_modifiers=[
+                        MeshShapeModifier.default_config().set(
+                            mesh_shape=mesh_shape_from_axes(data=-1, fsdp=32)
+                        ),
+                        RematSpecModifier.default_config().set(
+                            remat_policies={
+                                "model.decoder.transformer.layer": RematSpec(
+                                    prevent_cse=False,
+                                    policy=offload_dots_saveable_policy,
+                                ),
+                            }
+                        ),
+                    ],
+                ),
+            ),
             # TODO(kelvin-zou): Fix the env issue for internal use cases.
             # tpu-v5e-256-4. step time: 14.3736s (59.87% MFU).
             (
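The new rule combines two memory-saving techniques: resharding the mesh to pure FSDP across all 32 chips (mesh_shape_from_axes(data=-1, fsdp=32), where -1 appears to let the remaining data axis be inferred from the device count), and rematerializing each decoder layer with a policy that saves dot-product activations and offloads them off HBM. Below is a minimal sketch of the remat half, assuming plain JAX rather than axlearn's RematSpec machinery, and using the stock dots_saveable policy; the diff's offload_dots_saveable_policy presumably also moves the saved activations to host memory, which this sketch does not do.

import jax
import jax.numpy as jnp

def layer(x, w):
    # Toy stand-in for "model.decoder.transformer.layer".
    return jnp.tanh(x @ w)

# Save matmul (dot) outputs at the checkpoint boundary; recompute everything
# else in the backward pass. prevent_cse=False mirrors the RematSpec above
# and permits XLA common-subexpression elimination across the boundary.
remat_layer = jax.checkpoint(
    layer,
    policy=jax.checkpoint_policies.dots_saveable,
    prevent_cse=False,
)

def loss(x, w):
    return jnp.sum(remat_layer(x, w) ** 2)

x = jnp.ones((4, 8), jnp.float32)
w = jnp.ones((8, 8), jnp.float32)
print(jax.grad(loss, argnums=1)(x, w).shape)  # (8, 8)

Per the commit message, even with this configuration the 70B model still hits an HBM OOM on a v5e-32 slice.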
