
Commit 489a775

Adding a model config for the 7B model and v5litepod-32-2
Changing the batch size to match the chip count
1 parent f9fb40f commit 489a775

1 file changed (+21 lines, -1 line)

axlearn/experiments/text/gpt/fuji.py

Lines changed: 21 additions & 1 deletion
@@ -15,6 +15,7 @@
 import itertools
 from typing import Any, List, NamedTuple, Optional, Union
 
+import jax
 from jax.ad_checkpoint import checkpoint_policies as jax_remat_policies
 
 from axlearn.common import causal_lm, config
@@ -248,7 +249,8 @@ def get_trainer_kwargs(
     tokens_per_batch = TOKENS_PER_BATCH[version]
     max_step = TOTAL_TOKENS[version][model_size] // tokens_per_batch
     max_sequence_length = MAX_SEQUENCE_LENGTH[version]
-    train_batch_size = tokens_per_batch // max_sequence_length
+    # train_batch_size = tokens_per_batch // max_sequence_length
+    train_batch_size = len(jax.devices())
 
     # Whether to use grouped query attention.
     num_kv_heads = None
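
Note on the batch-size change (commentary, not part of the diff): jax.devices() enumerates every accelerator across all hosts in the job, so on a 32-chip v5litepod slice the global batch size becomes 32, i.e. one sequence per chip, and the effective tokens per batch becomes 32 * max_sequence_length rather than the fixed TOKENS_PER_BATCH value. A minimal sketch of the same calculation, runnable on any backend (the helper name below is illustrative, not from the commit):

```python
import jax

def train_batch_size_from_devices() -> int:
    # jax.devices() returns the global device list (all processes/hosts);
    # jax.local_devices() would return only this process's chips.
    return len(jax.devices())

if __name__ == "__main__":
    print("global devices:     ", len(jax.devices()))
    print("per-process devices:", len(jax.local_devices()))
    # On tpu-v5litepod-32 this prints 32; the commit uses it directly as the
    # global train batch size (one sequence per chip).
    print("train_batch_size:   ", train_batch_size_from_devices())
```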
@@ -392,6 +394,24 @@ def get_trainer_kwargs(
                 # tpu-v4-(1024|2048).
                 ("tpu-v4-(1024|2048)", mesh_shape_from_axes(data=-1, fsdp=16)),
                 # tpu-v5e.
+                (
+                    "tpu-v5litepod-32-2",
+                    ChainConfigModifier.default_config().set(
+                        config_modifiers=[
+                            MeshShapeModifier.default_config().set(
+                                mesh_shape=mesh_shape_from_axes(data=-1, fsdp=8)
+                            ),
+                            RematSpecModifier.default_config().set(
+                                remat_policies={
+                                    "model.decoder.transformer.layer": RematSpec(
+                                        prevent_cse=False,
+                                        policy=offload_dots_saveable_policy,
+                                    ),
+                                }
+                            ),
+                        ],
+                    ),
+                ),
                 (
                     "tpu-v5litepod-256",
                     ChainConfigModifier.default_config().set(
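
Note on the new mesh rule (commentary, not part of the diff): mesh_shape_from_axes(data=-1, fsdp=8) pins the fsdp axis to 8 chips and lets the data axis absorb the rest (-1), so on the 32 chips of a v5litepod-32 slice the mesh resolves to data=4, fsdp=8: parameters are sharded 8 ways within each of 4 data-parallel groups. A hedged sketch in plain JAX (not axlearn's API) of the mesh this implies:

```python
import numpy as np
import jax
from jax.sharding import Mesh

devices = np.asarray(jax.devices())
# The rule fixes fsdp at 8; fall back to the device count so this sketch also
# runs on a machine with fewer than 8 devices.
fsdp = 8 if devices.size % 8 == 0 else devices.size
data = devices.size // fsdp  # the data=-1 axis takes whatever is left
mesh = Mesh(devices.reshape(data, fsdp), axis_names=("data", "fsdp"))
print(dict(mesh.shape))  # on 32 chips: {'data': 4, 'fsdp': 8}
```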

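The RematSpec in the new rule also attaches an activation-rematerialization policy to every decoder layer. Judging by its name, offload_dots_saveable_policy is a dots-saveable checkpoint policy whose saved matmul outputs are additionally offloaded to host memory; the sketch below uses only the stock dots_saveable policy from JAX (the non-offloading relative, assumed here purely for illustration) together with the same prevent_cse=False setting:

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_policies as jax_remat_policies  # same import as fuji.py

# dots_saveable keeps matmul (dot) outputs for the backward pass and
# recomputes all other intermediates, trading compute for activation memory.
@partial(jax.checkpoint, policy=jax_remat_policies.dots_saveable, prevent_cse=False)
def decoder_layer(x, w):
    # Stand-in for a transformer layer: one matmul plus a nonlinearity.
    return jax.nn.relu(x @ w)

x = jnp.ones((4, 16))
w = jnp.ones((16, 16))
grads = jax.grad(lambda w: decoder_layer(x, w).sum())(w)
print(grads.shape)  # (16, 16)
```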