Skip to content

Commit 5af35b8

Browse files
committed
Adding the exclusive topology by hostname for the head job
Adding a model config for v5litepod-32 Changing the batch size to match the chip count and the checkpoint step interval to avoid any checkpoints for testing
1 parent 4f3c12f commit 5af35b8

File tree

2 files changed

+25
-0
lines changed

2 files changed

+25
-0
lines changed

axlearn/cloud/gcp/pathways_utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -429,6 +429,11 @@ def _build_pathways_head_job(self):
429429
annotations = _LoadBalancer(
430430
jobset_name=cfg.name, replicated_job_name=_PATHWAYS_HEAD_REPLICATED_JOB_NAME
431431
).metadata
432+
annotations.update(
433+
{
434+
"alpha.jobset.sigs.k8s.io/exclusive-topology": "kubernetes.io/hostname",
435+
}
436+
)
432437
spec = dict(
433438
parallelism=1,
434439
completions=1,

axlearn/experiments/text/gpt/fuji.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,7 @@ def get_trainer_kwargs(
249249
max_step = TOTAL_TOKENS[version][model_size] // tokens_per_batch
250250
max_sequence_length = MAX_SEQUENCE_LENGTH[version]
251251
train_batch_size = tokens_per_batch // max_sequence_length
252+
train_batch_size = 32
252253

253254
# Whether to use grouped query attention.
254255
num_kv_heads = None
@@ -392,6 +393,25 @@ def get_trainer_kwargs(
392393
# tpu-v4-(1024|2048).
393394
("tpu-v4-(1024|2048)", mesh_shape_from_axes(data=-1, fsdp=16)),
394395
# tpu-v5e.
396+
(
397+
"tpu-v5litepod-32",
398+
ChainConfigModifier.default_config().set(
399+
config_modifiers=[
400+
MeshShapeModifier.default_config().set(
401+
mesh_shape=mesh_shape_from_axes(data=-1, fsdp=32)
402+
),
403+
RematSpecModifier.default_config().set(
404+
remat_policies={
405+
"model.decoder.transformer.layer": RematSpec(
406+
prevent_cse=False,
407+
policy=offload_dots_saveable_policy,
408+
),
409+
}
410+
),
411+
GradientAccumulationModifier.default_config().set(grad_acc_steps=4),
412+
],
413+
),
414+
),
395415
(
396416
"tpu-v5litepod-256",
397417
ChainConfigModifier.default_config().set(

0 commit comments

Comments
 (0)