2 changes: 1 addition & 1 deletion src/MaxText/elastic_train.py
@@ -186,7 +186,7 @@ def train_loop(config, elastic_manager, recorder, state=None):
) = setup_train_loop(config, recorder)

p_train_step, _ = train_utils.jit_train_and_eval_step(config, model, mesh, state, state_mesh_shardings, train_step)
-with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
+with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
shaped_batch = maxtext_utils.get_shaped_batch(config)
compiled = p_train_step.lower(state, shaped_batch, init_rng).compile()
compiled_stats = compiled.memory_analysis()
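The same one-line substitution repeats in every file below: the Mesh object is no longer used directly as a context manager; jax.set_mesh(mesh) activates it instead. A minimal standalone sketch of the two patterns, not MaxText code, assuming a JAX version that provides jax.set_mesh; the "data" axis name and shapes are illustrative:

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Hypothetical 1-D mesh over all local devices.
n = jax.device_count()
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

# Old pattern (removed in this PR): the Mesh itself is the context manager.
with mesh:
  x = jax.device_put(jnp.ones((2 * n, 128)), NamedSharding(mesh, PartitionSpec("data")))

# New pattern (added in this PR): jax.set_mesh(mesh) makes the mesh current for the block.
with jax.set_mesh(mesh):
  y = jax.device_put(jnp.ones((2 * n, 128)), NamedSharding(mesh, PartitionSpec("data")))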
4 changes: 2 additions & 2 deletions src/MaxText/experimental/rl/grpo_trainer.py
@@ -824,7 +824,7 @@ def generation_worker_fn(
continue
train_rng, rng = random.split(init_rng)
example_batch = jax.device_put(example_batch, data_sharding)
-with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
+with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
state, metrics = p_train_step(state, example_batch, train_rng)
with jax.profiler.StepTraceAnnotation("transfer data", step_num=step):
if step != 0 and step % config.inference_rollouts == 0:
@@ -862,7 +862,7 @@ def generation_worker_fn(
for eval_batch in eval_data_iterator:
if 0 < config.eval_steps <= eval_step_count:
break
-with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
+with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
eval_metrics = p_eval_step(state, eval_batch, rng)
metric_logger.record_eval_metrics(step, metrics=eval_metrics)
max_logging.log(f"Completed eval step {eval_step_count}")
6 changes: 3 additions & 3 deletions src/MaxText/maxtext_utils.py
@@ -933,7 +933,7 @@ def move(path, x):

unboxed_abstract_sharded_state = max_utils.unbox_logicallypartioned(abstract_sharded_state)
# Initialization
-with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
+with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
state_mesh_annotations = nn.logical_to_mesh(state_logical_annotations)
return (
unboxed_abstract_sharded_state,
@@ -969,7 +969,7 @@ def init_kv_cache(model, config):
init_kv_cache_partial = functools.partial(init_kv_cache, model, config)
abstract_state = jax.eval_shape(init_kv_cache_partial)
state_logical_annotations = nn.get_partition_spec(abstract_state)
-with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
+with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
state_mesh_annotations = nn.logical_to_mesh(state_logical_annotations)
return state_mesh_annotations

@@ -998,7 +998,7 @@ def init_kv_cache(model, config):
init_kv_cache_partial = functools.partial(init_kv_cache, model, config)
abstract_state = jax.eval_shape(init_kv_cache_partial)
state_logical_annotations = nn.get_partition_spec(abstract_state)
-with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
+with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
state_mesh_annotations = nn.logical_to_mesh(state_logical_annotations)
return state_mesh_annotations

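The three maxtext_utils.py hunks share the same flow: evaluate the abstract state, extract logical PartitionSpecs, then translate them to mesh axes inside the active mesh and axis rules. A minimal, self-contained sketch of that flow under the new jax.set_mesh pattern; the toy module, axis rules, and mesh axis names are assumptions, not MaxText code:

import numpy as np
import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.linen import partitioning as nn_partitioning
from jax.sharding import Mesh


class TinyMLP(nn.Module):
  """Toy module whose single weight carries logical axis names."""

  @nn.compact
  def __call__(self, x):
    w = self.param("w", nn.with_partitioning(nn.initializers.ones, ("embed", "mlp")), (4, 8))
    return x @ w


model = TinyMLP()
# Shape-only evaluation, mirroring jax.eval_shape(init_kv_cache_partial) above.
abstract_state = jax.eval_shape(lambda: model.init(jax.random.PRNGKey(0), jnp.ones((2, 4))))
state_logical_annotations = nn.get_partition_spec(abstract_state)

# Illustrative mesh and logical->mesh axis rules.
mesh = Mesh(np.array(jax.devices()).reshape((-1, 1)), ("fsdp", "tensor"))
logical_axis_rules = (("embed", "fsdp"), ("mlp", "tensor"))

with jax.set_mesh(mesh), nn_partitioning.axis_rules(logical_axis_rules):
  state_mesh_annotations = nn.logical_to_mesh(state_logical_annotations)
print(state_mesh_annotations)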
2 changes: 1 addition & 1 deletion src/MaxText/model_creation_utils.py
@@ -154,7 +154,7 @@ def create_sharded_state():
model = _create_model_partial()
return nnx.state(model)

-with mesh:
+with jax.set_mesh(mesh):
# Create the model with sharded parameters.
with nn.logical_axis_rules(config.logical_axis_rules):
sharded_state = create_sharded_state()
2 changes: 1 addition & 1 deletion src/MaxText/rl/train_rl.py
@@ -92,7 +92,7 @@ def get_maxtext_model(config, devices=None):
# load_parameters_path=/path/to/your/output/directory/0/items
"""
model, mesh = model_creation_utils.create_nnx_model(config, devices=devices)
-with mesh:
+with jax.set_mesh(mesh):
tunix_model = TunixMaxTextAdapter(base_model=model)
tunix_model.config = None
return tunix_model, mesh
2 changes: 1 addition & 1 deletion src/MaxText/sft/sft_trainer.py
@@ -160,7 +160,7 @@ def setup_trainer_state(mt_config, goodput_recorder=None):

def train_model(mt_config, trainer, mesh):
"""Runs the SFT training loop in Tunix."""
-with mesh, nn_partitioning.axis_rules(mt_config.logical_axis_rules):
+with jax.set_mesh(mesh), nn_partitioning.axis_rules(mt_config.logical_axis_rules):
trainer.train(trainer.data_hooks.train_data_iterator, trainer.data_hooks.eval_data_iterator)
return trainer

6 changes: 3 additions & 3 deletions src/MaxText/sft_trainer.py
@@ -75,7 +75,7 @@ def train_loop(config, recorder, state=None):
config, model, mesh, state, state_mesh_shardings, train_step, eval_step, eval_data_iterator
)

-with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
+with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
shaped_batch = maxtext_utils.get_shaped_batch(config)
compiled = p_train_step.lower(state, shaped_batch, init_rng).compile()
compiled_stats = compiled.memory_analysis()
@@ -99,7 +99,7 @@ def train_loop(config, recorder, state=None):
# pylint: disable=not-callable
nextrng = jax.jit(jax.random.fold_in)(init_rng, step)
with maybe_record_goodput(recorder, GoodputEvent.STEP, step):
-with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
+with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
state, metrics = p_train_step(state, example_batch, nextrng)

step_time_delta = datetime.datetime.now() - last_step_completion
@@ -124,7 +124,7 @@ def train_loop(config, recorder, state=None):
for eval_batch in eval_data_iterator:
if config.eval_steps > 0 and eval_step_count >= config.eval_steps:
break
-with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
+with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
eval_metrics = p_eval_step(state, eval_batch, nextrng)
metric_logger.record_eval_metrics(step, metrics=eval_metrics)
max_logging.log(f"Completed eval step {eval_step_count}")
6 changes: 3 additions & 3 deletions src/MaxText/train.py
@@ -402,7 +402,7 @@ def train_loop(config, recorder, state=None):
params_shardings,
)

-with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
+with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
shaped_batch = maxtext_utils.get_shaped_batch(config)
if config.shard_optimizer_over_data:
state = sharding.maybe_shard_with_name(state, state_mesh_shardings, config.shard_mode)
@@ -434,7 +434,7 @@ def train_loop(config, recorder, state=None):
# pylint: disable=not-callable
nextrng = jax.jit(jax.random.fold_in)(init_rng, step)
with maybe_record_goodput(recorder, GoodputEvent.STEP, step):
-with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
+with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
if config.shard_optimizer_over_data:
state = sharding.maybe_shard_with_name(state, state_mesh_shardings, config.shard_mode)
state, metrics = p_train_step(state, example_batch, nextrng)
@@ -466,7 +466,7 @@ def train_loop(config, recorder, state=None):
for eval_batch in eval_data_iterator:
if config.eval_steps > 0 and eval_step_count >= config.eval_steps:
break
-with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
+with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
eval_metrics = p_eval_step(state, eval_batch, nextrng)
metric_logger.record_eval_metrics(step, metrics=eval_metrics)
max_logging.log(f"Completed eval step {eval_step_count}")
2 changes: 1 addition & 1 deletion src/MaxText/train_compile.py
@@ -120,7 +120,7 @@ def jit_and_compile(
logical_axis_rules,
):
"""Jit, lower, and compile func."""
-with mesh, logical_axis_rules:
+with jax.set_mesh(mesh), logical_axis_rules:
jitted = jax.jit(
func,
in_shardings=in_shardings,
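For context on why jit_and_compile needs the mesh active, here is a minimal sketch of lowering and compiling a jitted function under jax.set_mesh, loosely mirroring the ahead-of-time flow in elastic_train.py and train_compile.py; loss_fn, the shapes, and the axis name are illustrative, not MaxText code:

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def loss_fn(batch):
  return jnp.mean(batch * batch)


mesh = Mesh(np.array(jax.devices()), ("data",))
n = jax.device_count()

with jax.set_mesh(mesh):
  jitted = jax.jit(loss_fn, in_shardings=NamedSharding(mesh, PartitionSpec("data")))
  # Lower against abstract shapes, then compile, as p_train_step.lower(...).compile()
  # does in the training scripts above.
  shaped_batch = jax.ShapeDtypeStruct((2 * n, 128), jnp.float32)
  compiled = jitted.lower(shaped_batch).compile()
  print(compiled.memory_analysis())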
2 changes: 1 addition & 1 deletion src/MaxText/train_utils.py
@@ -193,7 +193,7 @@ def setup_train_loop(config, recorder, devices=None):
)

# Apply reordering wrapper to data iterators if context parallelism is enabled
-with mesh:
+with jax.set_mesh(mesh):
if context_parallel_size > 1 and config.context_parallel_load_balance:
data_iterator = map(maxtext_utils.get_reorder_callable(context_parallel_size, config.shard_mode), data_iterator)
if eval_data_iterator:
6 changes: 3 additions & 3 deletions tests/maxtext_utils_test.py
@@ -374,7 +374,7 @@ def test_multi_axis_sharding_pass(self):
devices = np.array(jax.devices()).reshape((4, 1, 1, 1, 1))
mesh = Mesh(devices, self.mesh_axes)

-with mesh:
+with jax.set_mesh(mesh):
# Shard across multiple axes, including the valid 'fsdp' axis.
pspec = PartitionSpec(("fsdp", "sequence"), "stage", ("tensor"), None)
params = {"complex_layer": jax.device_put(jnp.ones((8, 8, 2, 2)), NamedSharding(mesh, pspec))}
@@ -389,7 +389,7 @@ def test_multi_axis_not_sharded_fails(self):
"""
devices = np.array(jax.devices()).reshape((4, 1, 1, 1, 1))
mesh = Mesh(devices, self.mesh_axes)
-with mesh:
+with jax.set_mesh(mesh):
pspec = PartitionSpec(("sequence", "context"), "stage", "tensor", None)
params = {"complex_layer": jax.device_put(jnp.ones((8, 8, 2, 2)), NamedSharding(mesh, pspec))}

@@ -402,7 +402,7 @@ def test_multi_axis_mixed_sharding_fails(self):
"""
devices = np.array(jax.devices()).reshape((4, 1, 1, 1, 1))
mesh = Mesh(devices, self.mesh_axes)
-with mesh:
+with jax.set_mesh(mesh):
sharded_pspec = PartitionSpec(("fsdp", "sequence"), "stage", ("tensor"), None)
sharded_param = jax.device_put(jnp.ones((8, 8, 2, 2)), NamedSharding(mesh, sharded_pspec))
unsharded_param = jnp.ones((8, 8, 2, 2))
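The updated tests exercise the same API through a five-axis mesh. A standalone sketch of that setup under jax.set_mesh; the axis names and their order are assumptions here (the tests keep theirs in self.mesh_axes), and on CPU you may need XLA_FLAGS=--xla_force_host_platform_device_count=4 to get more than one device:

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh_axes = ("fsdp", "sequence", "stage", "tensor", "data")  # assumed names and order
n = jax.device_count()
devices = np.array(jax.devices()).reshape((n, 1, 1, 1, 1))
mesh = Mesh(devices, mesh_axes)

with jax.set_mesh(mesh):
  # Shard the leading dim over ("fsdp", "sequence"); only "fsdp" has size > 1 here.
  pspec = PartitionSpec(("fsdp", "sequence"), "stage", "tensor", None)
  param = jax.device_put(jnp.ones((2 * n, 8, 2, 2)), NamedSharding(mesh, pspec))
  print(param.sharding)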