We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent a7e0f51 commit bd04c4dCopy full SHA for bd04c4d
axlearn/experiments/text/gpt/fuji.py
@@ -252,8 +252,7 @@ def get_trainer_kwargs(
252
tokens_per_batch = TOKENS_PER_BATCH[version]
253
max_step = TOTAL_TOKENS[version][model_size] // tokens_per_batch
254
max_sequence_length = MAX_SEQUENCE_LENGTH[version]
255
- # train_batch_size = tokens_per_batch // max_sequence_length
256
- train_batch_size = len(jax.devices())
+ train_batch_size = tokens_per_batch // max_sequence_length
257
258
# Whether to use grouped query attention.
259
num_kv_heads = None
0 commit comments