Skip to content

Commit bd04c4d

Browse files
committed
Remove the batch_size=len(jax.devices()) workaround as it is not needed after the JAX patch
1 parent a7e0f51 commit bd04c4d

File tree

1 file changed

+1
-2
lines changed
  • axlearn/experiments/text/gpt

1 file changed

+1
-2
lines changed

axlearn/experiments/text/gpt/fuji.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -252,8 +252,7 @@ def get_trainer_kwargs(
252252
tokens_per_batch = TOKENS_PER_BATCH[version]
253253
max_step = TOTAL_TOKENS[version][model_size] // tokens_per_batch
254254
max_sequence_length = MAX_SEQUENCE_LENGTH[version]
255-
# train_batch_size = tokens_per_batch // max_sequence_length
256-
train_batch_size = len(jax.devices())
255+
train_batch_size = tokens_per_batch // max_sequence_length
257256

258257
# Whether to use grouped query attention.
259258
num_kv_heads = None

0 commit comments

Comments (0)