|
15 | 15 | import itertools
|
16 | 16 | from typing import Any, List, NamedTuple, Optional, Union
|
17 | 17 |
|
| 18 | +import jax |
18 | 19 | from jax.ad_checkpoint import checkpoint_policies as jax_remat_policies
|
19 | 20 |
|
20 | 21 | from axlearn.common import causal_lm, config
|
@@ -248,7 +249,8 @@ def get_trainer_kwargs(
|
248 | 249 | tokens_per_batch = TOKENS_PER_BATCH[version]
|
249 | 250 | max_step = TOTAL_TOKENS[version][model_size] // tokens_per_batch
|
250 | 251 | max_sequence_length = MAX_SEQUENCE_LENGTH[version]
|
251 |
| - train_batch_size = tokens_per_batch // max_sequence_length |
| 252 | + # train_batch_size = tokens_per_batch // max_sequence_length |
| 253 | + train_batch_size = len(jax.devices()) |
252 | 254 |
|
253 | 255 | # Whether to use grouped query attention.
|
254 | 256 | num_kv_heads = None
|
@@ -392,6 +394,24 @@ def get_trainer_kwargs(
|
392 | 394 | # tpu-v4-(1024|2048).
|
393 | 395 | ("tpu-v4-(1024|2048)", mesh_shape_from_axes(data=-1, fsdp=16)),
|
394 | 396 | # tpu-v5e.
|
| 397 | + ( |
| 398 | + "tpu-v5litepod-32-2", |
| 399 | + ChainConfigModifier.default_config().set( |
| 400 | + config_modifiers=[ |
| 401 | + MeshShapeModifier.default_config().set( |
| 402 | + mesh_shape=mesh_shape_from_axes(data=-1, fsdp=8) |
| 403 | + ), |
| 404 | + RematSpecModifier.default_config().set( |
| 405 | + remat_policies={ |
| 406 | + "model.decoder.transformer.layer": RematSpec( |
| 407 | + prevent_cse=False, |
| 408 | + policy=offload_dots_saveable_policy, |
| 409 | + ), |
| 410 | + } |
| 411 | + ), |
| 412 | + ], |
| 413 | + ), |
| 414 | + ), |
395 | 415 | (
|
396 | 416 | "tpu-v5litepod-256",
|
397 | 417 | ChainConfigModifier.default_config().set(
|
|
0 commit comments