Skip to content

Commit e7cd77a

Browse files
authored
Reset the patch for axlearn (#1945)
1 parent 537402b commit e7cd77a

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

.github/container/fuji-train-perf.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,8 @@ def main(parsed_args):
278278
f"======================\n"
279279
)
280280

281+
# Call AXLearn Jax setup
282+
launch.setup()
281283
# Build the model config
282284
config_fn = c4_trainer.named_trainer_configs()[config_name]
283285
trainer_config: SpmdTrainer.Config = config_for_function(config_fn).fn()
@@ -320,8 +322,6 @@ def main(parsed_args):
320322
if trace_steps is not None:
321323
trainer_config.start_trace_steps = trace_steps
322324

323-
# Call AXLearn Jax setup
324-
launch.setup()
325325
# Setup the config
326326
trainer_config.set(
327327
recorder=config_for_function(lambda: measurement.global_recorder)

0 commit comments

Comments
 (0)