@@ -75,7 +75,7 @@ def train_loop(config, recorder, state=None):
7575 config , model , mesh , state , state_mesh_shardings , train_step , eval_step , eval_data_iterator
7676 )
7777
78- with mesh , nn_partitioning .axis_rules (config .logical_axis_rules ):
78+ with jax . set_mesh ( mesh ) , nn_partitioning .axis_rules (config .logical_axis_rules ):
7979 shaped_batch = maxtext_utils .get_shaped_batch (config )
8080 compiled = p_train_step .lower (state , shaped_batch , init_rng ).compile ()
8181 compiled_stats = compiled .memory_analysis ()
@@ -99,7 +99,7 @@ def train_loop(config, recorder, state=None):
9999 # pylint: disable=not-callable
100100 nextrng = jax .jit (jax .random .fold_in )(init_rng , step )
101101 with maybe_record_goodput (recorder , GoodputEvent .STEP , step ):
102- with mesh , nn_partitioning .axis_rules (config .logical_axis_rules ):
102+ with jax . set_mesh ( mesh ) , nn_partitioning .axis_rules (config .logical_axis_rules ):
103103 state , metrics = p_train_step (state , example_batch , nextrng )
104104
105105 step_time_delta = datetime .datetime .now () - last_step_completion
@@ -124,7 +124,7 @@ def train_loop(config, recorder, state=None):
124124 for eval_batch in eval_data_iterator :
125125 if config .eval_steps > 0 and eval_step_count >= config .eval_steps :
126126 break
127- with mesh , nn_partitioning .axis_rules (config .logical_axis_rules ):
127+ with jax . set_mesh ( mesh ) , nn_partitioning .axis_rules (config .logical_axis_rules ):
128128 eval_metrics = p_eval_step (state , eval_batch , nextrng )
129129 metric_logger .record_eval_metrics (step , metrics = eval_metrics )
130130 max_logging .log (f"Completed eval step { eval_step_count } " )
0 commit comments