We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 537402b commit e7cd77aCopy full SHA for e7cd77a
.github/container/fuji-train-perf.py
@@ -278,6 +278,8 @@ def main(parsed_args):
278
f"======================\n"
279
)
280
281
+ # Call AXLearn Jax setup
282
+ launch.setup()
283
# Build the model config
284
config_fn = c4_trainer.named_trainer_configs()[config_name]
285
trainer_config: SpmdTrainer.Config = config_for_function(config_fn).fn()
@@ -320,8 +322,6 @@ def main(parsed_args):
320
322
if trace_steps is not None:
321
323
trainer_config.start_trace_steps = trace_steps
324
- # Call AXLearn Jax setup
- launch.setup()
325
# Setup the config
326
trainer_config.set(
327
recorder=config_for_function(lambda: measurement.global_recorder)
0 commit comments