We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 168e41e commit 1af5203Copy full SHA for 1af5203
axlearn/common/launch_trainer.py
@@ -163,6 +163,8 @@ def run_trainer(trainer_config: SpmdTrainer.Config) -> Any:
163
ten_minutes = 10 * 60
164
elastic_manager.wait_for_slices(timeout=ten_minutes)
165
else:
166
+ trainer: SpmdTrainer = trainer_config.instantiate(parent=None)
167
+ prng_key = jax.random.PRNGKey(seed=FLAGS.trainer_prng_seed)
168
output = trainer.run(prng_key)
169
170
measurement.record_event(measurement.Event.END_JOB)
0 commit comments