File tree Expand file tree Collapse file tree 1 file changed +19
-0
lines changed
axlearn/experiments/text/gpt Expand file tree Collapse file tree 1 file changed +19
-0
lines changed Original file line number Diff line number Diff line change @@ -673,6 +673,25 @@ def get_trainer_kwargs(
673
673
mesh_rules = (
674
674
# TPU V5e maximum per device batch is 1.
675
675
# with all activation offloading, HBM usage: 14.6GB/chip.
676
+ # tpu-v5e-32-1
677
+ (
678
+ "tpu-v5litepod-32-1" ,
679
+ ChainConfigModifier .default_config ().set (
680
+ config_modifiers = [
681
+ MeshShapeModifier .default_config ().set (
682
+ mesh_shape = mesh_shape_from_axes (data = - 1 , fsdp = 32 )
683
+ ),
684
+ RematSpecModifier .default_config ().set (
685
+ remat_policies = {
686
+ "model.decoder.transformer.layer" : RematSpec (
687
+ prevent_cse = False ,
688
+ policy = offload_dots_saveable_policy ,
689
+ ),
690
+ }
691
+ ),
692
+ ],
693
+ ),
694
+ ),
676
695
# TODO(kelvin-zou): Fix the env issue for internal use cases.
677
696
# tpu-v5e-256-4. step time: 14.3736s (59.87% MFU).
678
697
(
You can’t perform that action at this time.
0 commit comments