Skip to content

Commit 722d7e2

Browse files
committed
update with mesh
1 parent 094b41d commit 722d7e2

File tree

11 files changed, with 65 additions and 74 deletions.

src/MaxText/elastic_train.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -186,7 +186,7 @@ def train_loop(config, elastic_manager, recorder, state=None):
186186
) = setup_train_loop(config, recorder)
187187

188188
p_train_step, _ = train_utils.jit_train_and_eval_step(config, model, mesh, state, state_mesh_shardings, train_step)
189-
with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
189+
with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
190190
shaped_batch = maxtext_utils.get_shaped_batch(config)
191191
compiled = p_train_step.lower(state, shaped_batch, init_rng).compile()
192192
compiled_stats = compiled.memory_analysis()

src/MaxText/experimental/rl/grpo_trainer.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -824,7 +824,7 @@ def generation_worker_fn(
824824
continue
825825
train_rng, rng = random.split(init_rng)
826826
example_batch = jax.device_put(example_batch, data_sharding)
827-
with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
827+
with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
828828
state, metrics = p_train_step(state, example_batch, train_rng)
829829
with jax.profiler.StepTraceAnnotation("transfer data", step_num=step):
830830
if step != 0 and step % config.inference_rollouts == 0:
@@ -862,7 +862,7 @@ def generation_worker_fn(
862862
for eval_batch in eval_data_iterator:
863863
if 0 < config.eval_steps <= eval_step_count:
864864
break
865-
with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
865+
with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
866866
eval_metrics = p_eval_step(state, eval_batch, rng)
867867
metric_logger.record_eval_metrics(step, metrics=eval_metrics)
868868
max_logging.log(f"Completed eval step {eval_step_count}")

src/MaxText/maxtext_utils.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -933,7 +933,7 @@ def move(path, x):
933933

934934
unboxed_abstract_sharded_state = max_utils.unbox_logicallypartioned(abstract_sharded_state)
935935
# Initialization
936-
with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
936+
with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
937937
state_mesh_annotations = nn.logical_to_mesh(state_logical_annotations)
938938
return (
939939
unboxed_abstract_sharded_state,
@@ -969,7 +969,7 @@ def init_kv_cache(model, config):
969969
init_kv_cache_partial = functools.partial(init_kv_cache, model, config)
970970
abstract_state = jax.eval_shape(init_kv_cache_partial)
971971
state_logical_annotations = nn.get_partition_spec(abstract_state)
972-
with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
972+
with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
973973
state_mesh_annotations = nn.logical_to_mesh(state_logical_annotations)
974974
return state_mesh_annotations
975975

@@ -998,7 +998,7 @@ def init_kv_cache(model, config):
998998
init_kv_cache_partial = functools.partial(init_kv_cache, model, config)
999999
abstract_state = jax.eval_shape(init_kv_cache_partial)
10001000
state_logical_annotations = nn.get_partition_spec(abstract_state)
1001-
with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
1001+
with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
10021002
state_mesh_annotations = nn.logical_to_mesh(state_logical_annotations)
10031003
return state_mesh_annotations
10041004

src/MaxText/model_creation_utils.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -154,7 +154,7 @@ def create_sharded_state():
154154
model = _create_model_partial()
155155
return nnx.state(model)
156156

157-
with mesh:
157+
with jax.set_mesh(mesh):
158158
# Create the model with sharded parameters.
159159
with nn.logical_axis_rules(config.logical_axis_rules):
160160
sharded_state = create_sharded_state()

src/MaxText/rl/train_rl.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -92,7 +92,7 @@ def get_maxtext_model(config, devices=None):
9292
# load_parameters_path=/path/to/your/output/directory/0/items
9393
"""
9494
model, mesh = model_creation_utils.create_nnx_model(config, devices=devices)
95-
with mesh:
95+
with jax.set_mesh(mesh):
9696
tunix_model = TunixMaxTextAdapter(base_model=model)
9797
tunix_model.config = None
9898
return tunix_model, mesh

src/MaxText/sft/sft_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -160,7 +160,7 @@ def setup_trainer_state(mt_config, goodput_recorder=None):
160160

161161
def train_model(mt_config, trainer, mesh):
162162
"""Runs the SFT training loop in Tunix."""
163-
with mesh, nn_partitioning.axis_rules(mt_config.logical_axis_rules):
163+
with jax.set_mesh(mesh), nn_partitioning.axis_rules(mt_config.logical_axis_rules):
164164
trainer.train(trainer.data_hooks.train_data_iterator, trainer.data_hooks.eval_data_iterator)
165165
return trainer
166166

src/MaxText/sft_trainer.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -75,7 +75,7 @@ def train_loop(config, recorder, state=None):
7575
config, model, mesh, state, state_mesh_shardings, train_step, eval_step, eval_data_iterator
7676
)
7777

78-
with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
78+
with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
7979
shaped_batch = maxtext_utils.get_shaped_batch(config)
8080
compiled = p_train_step.lower(state, shaped_batch, init_rng).compile()
8181
compiled_stats = compiled.memory_analysis()
@@ -99,7 +99,7 @@ def train_loop(config, recorder, state=None):
9999
# pylint: disable=not-callable
100100
nextrng = jax.jit(jax.random.fold_in)(init_rng, step)
101101
with maybe_record_goodput(recorder, GoodputEvent.STEP, step):
102-
with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
102+
with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
103103
state, metrics = p_train_step(state, example_batch, nextrng)
104104

105105
step_time_delta = datetime.datetime.now() - last_step_completion
@@ -124,7 +124,7 @@ def train_loop(config, recorder, state=None):
124124
for eval_batch in eval_data_iterator:
125125
if config.eval_steps > 0 and eval_step_count >= config.eval_steps:
126126
break
127-
with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
127+
with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
128128
eval_metrics = p_eval_step(state, eval_batch, nextrng)
129129
metric_logger.record_eval_metrics(step, metrics=eval_metrics)
130130
max_logging.log(f"Completed eval step {eval_step_count}")

src/MaxText/train.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -402,7 +402,7 @@ def train_loop(config, recorder, state=None):
402402
params_shardings,
403403
)
404404

405-
with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
405+
with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
406406
shaped_batch = maxtext_utils.get_shaped_batch(config)
407407
if config.shard_optimizer_over_data:
408408
state = sharding.maybe_shard_with_name(state, state_mesh_shardings, config.shard_mode)
@@ -434,7 +434,7 @@ def train_loop(config, recorder, state=None):
434434
# pylint: disable=not-callable
435435
nextrng = jax.jit(jax.random.fold_in)(init_rng, step)
436436
with maybe_record_goodput(recorder, GoodputEvent.STEP, step):
437-
with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
437+
with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
438438
if config.shard_optimizer_over_data:
439439
state = sharding.maybe_shard_with_name(state, state_mesh_shardings, config.shard_mode)
440440
state, metrics = p_train_step(state, example_batch, nextrng)
@@ -466,7 +466,7 @@ def train_loop(config, recorder, state=None):
466466
for eval_batch in eval_data_iterator:
467467
if config.eval_steps > 0 and eval_step_count >= config.eval_steps:
468468
break
469-
with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
469+
with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
470470
eval_metrics = p_eval_step(state, eval_batch, nextrng)
471471
metric_logger.record_eval_metrics(step, metrics=eval_metrics)
472472
max_logging.log(f"Completed eval step {eval_step_count}")

src/MaxText/train_compile.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -120,7 +120,7 @@ def jit_and_compile(
120120
logical_axis_rules,
121121
):
122122
"""Jit, lower, and compile func."""
123-
with mesh, logical_axis_rules:
123+
with jax.set_mesh(mesh), logical_axis_rules:
124124
jitted = jax.jit(
125125
func,
126126
in_shardings=in_shardings,

src/MaxText/train_utils.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -193,7 +193,7 @@ def setup_train_loop(config, recorder, devices=None):
193193
)
194194

195195
# Apply reordering wrapper to data iterators if context parallelism is enabled
196-
with mesh:
196+
with jax.set_mesh(mesh):
197197
if context_parallel_size > 1 and config.context_parallel_load_balance:
198198
data_iterator = map(maxtext_utils.get_reorder_callable(context_parallel_size, config.shard_mode), data_iterator)
199199
if eval_data_iterator:

0 commit comments

Comments
 (0)