diff --git a/src/MaxText/train.py b/src/MaxText/train.py index fb5fcebdb..bebd737e5 100644 --- a/src/MaxText/train.py +++ b/src/MaxText/train.py @@ -412,7 +412,7 @@ def train_loop(config, recorder, state=None): max_utils.print_compiled_memory_stats(compiled_stats) start_step = get_first_step(state) # this is the start_step for training - prof = profiler.Profiler(config, offset_step=start_step) + prof = profiler.Profiler(config, offset_step=start_step, ) metric_logger = MetricLogger(config=config, learning_rate_schedule=learning_rate_schedule) # Write train config params, num model params, and XLA flags to tensorboard