2 changes: 1 addition & 1 deletion src/MaxText/elastic_train.py
@@ -186,7 +186,7 @@ def train_loop(config, elastic_manager, recorder, state=None):
) = setup_train_loop(config, recorder)

p_train_step, _ = train_utils.jit_train_and_eval_step(config, model, mesh, state, state_mesh_shardings, train_step)
with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
shaped_batch = maxtext_utils.get_shaped_batch(config)
compiled = p_train_step.lower(state, shaped_batch, init_rng).compile()
compiled_stats = compiled.memory_analysis()
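The recurring change in this PR swaps the legacy `with mesh:` context for `with jax.set_mesh(mesh):`. Below is a minimal sketch of the new pattern, not taken from MaxText, assuming a recent JAX release that ships `jax.make_mesh` and `jax.set_mesh`; the axis names are illustrative only.

```python
# Minimal sketch (not part of this diff): the pattern the PR migrates to.
# Assumes a JAX release providing jax.make_mesh and jax.set_mesh; axis names
# are hypothetical.
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec

n = jax.device_count()
mesh = jax.make_mesh((n, 1), ("fsdp", "tensor"))

# New style: jax.set_mesh(mesh) as a context manager instead of `with mesh:`.
with jax.set_mesh(mesh):
    x = jax.device_put(
        jnp.ones((n * 2, 8)),
        NamedSharding(mesh, PartitionSpec("fsdp", "tensor")),
    )
    y = jax.jit(lambda a: a * 2)(x)
    print(y.sharding)
```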
4 changes: 2 additions & 2 deletions src/MaxText/experimental/rl/grpo_trainer.py
@@ -824,7 +824,7 @@ def generation_worker_fn(
continue
train_rng, rng = random.split(init_rng)
example_batch = jax.device_put(example_batch, data_sharding)
with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
state, metrics = p_train_step(state, example_batch, train_rng)
with jax.profiler.StepTraceAnnotation("transfer data", step_num=step):
if step != 0 and step % config.inference_rollouts == 0:
@@ -862,7 +862,7 @@ def generation_worker_fn(
for eval_batch in eval_data_iterator:
if 0 < config.eval_steps <= eval_step_count:
break
with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
eval_metrics = p_eval_step(state, eval_batch, rng)
metric_logger.record_eval_metrics(step, metrics=eval_metrics)
max_logging.log(f"Completed eval step {eval_step_count}")
6 changes: 3 additions & 3 deletions src/MaxText/maxtext_utils.py
@@ -933,7 +933,7 @@ def move(path, x):

unboxed_abstract_sharded_state = max_utils.unbox_logicallypartioned(abstract_sharded_state)
# Initialization
with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
state_mesh_annotations = nn.logical_to_mesh(state_logical_annotations)
return (
unboxed_abstract_sharded_state,
@@ -969,7 +969,7 @@ def init_kv_cache(model, config):
init_kv_cache_partial = functools.partial(init_kv_cache, model, config)
abstract_state = jax.eval_shape(init_kv_cache_partial)
state_logical_annotations = nn.get_partition_spec(abstract_state)
with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
state_mesh_annotations = nn.logical_to_mesh(state_logical_annotations)
return state_mesh_annotations

@@ -998,7 +998,7 @@ def init_kv_cache(model, config):
init_kv_cache_partial = functools.partial(init_kv_cache, model, config)
abstract_state = jax.eval_shape(init_kv_cache_partial)
state_logical_annotations = nn.get_partition_spec(abstract_state)
with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
state_mesh_annotations = nn.logical_to_mesh(state_logical_annotations)
return state_mesh_annotations

2 changes: 1 addition & 1 deletion src/MaxText/model_creation_utils.py
@@ -154,7 +154,7 @@ def create_sharded_state():
model = _create_model_partial()
return nnx.state(model)

with mesh:
with jax.set_mesh(mesh):
# Create the model with sharded parameters.
with nn.logical_axis_rules(config.logical_axis_rules):
sharded_state = create_sharded_state()
2 changes: 1 addition & 1 deletion src/MaxText/rl/train_rl.py
@@ -92,7 +92,7 @@ def get_maxtext_model(config, devices=None):
# load_parameters_path=/path/to/your/output/directory/0/items
"""
model, mesh = model_creation_utils.create_nnx_model(config, devices=devices)
with mesh:
with jax.set_mesh(mesh):
tunix_model = TunixMaxTextAdapter(base_model=model)
tunix_model.config = None
return tunix_model, mesh
2 changes: 1 addition & 1 deletion src/MaxText/sft/sft_trainer.py
@@ -160,7 +160,7 @@ def setup_trainer_state(mt_config, goodput_recorder=None):

def train_model(mt_config, trainer, mesh):
"""Runs the SFT training loop in Tunix."""
with mesh, nn_partitioning.axis_rules(mt_config.logical_axis_rules):
with jax.set_mesh(mesh), nn_partitioning.axis_rules(mt_config.logical_axis_rules):
trainer.train(trainer.data_hooks.train_data_iterator, trainer.data_hooks.eval_data_iterator)
return trainer

6 changes: 3 additions & 3 deletions src/MaxText/sft_trainer.py
@@ -75,7 +75,7 @@ def train_loop(config, recorder, state=None):
config, model, mesh, state, state_mesh_shardings, train_step, eval_step, eval_data_iterator
)

with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
shaped_batch = maxtext_utils.get_shaped_batch(config)
compiled = p_train_step.lower(state, shaped_batch, init_rng).compile()
compiled_stats = compiled.memory_analysis()
@@ -99,7 +99,7 @@ def train_loop(config, recorder, state=None):
# pylint: disable=not-callable
nextrng = jax.jit(jax.random.fold_in)(init_rng, step)
with maybe_record_goodput(recorder, GoodputEvent.STEP, step):
with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
state, metrics = p_train_step(state, example_batch, nextrng)

step_time_delta = datetime.datetime.now() - last_step_completion
@@ -124,7 +124,7 @@ def train_loop(config, recorder, state=None):
for eval_batch in eval_data_iterator:
if config.eval_steps > 0 and eval_step_count >= config.eval_steps:
break
with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
eval_metrics = p_eval_step(state, eval_batch, nextrng)
metric_logger.record_eval_metrics(step, metrics=eval_metrics)
max_logging.log(f"Completed eval step {eval_step_count}")
6 changes: 3 additions & 3 deletions src/MaxText/train.py
@@ -402,7 +402,7 @@ def train_loop(config, recorder, state=None):
params_shardings,
)

with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
shaped_batch = maxtext_utils.get_shaped_batch(config)
if config.shard_optimizer_over_data:
state = sharding.maybe_shard_with_name(state, state_mesh_shardings, config.shard_mode)
@@ -434,7 +434,7 @@ def train_loop(config, recorder, state=None):
# pylint: disable=not-callable
nextrng = jax.jit(jax.random.fold_in)(init_rng, step)
with maybe_record_goodput(recorder, GoodputEvent.STEP, step):
with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
if config.shard_optimizer_over_data:
state = sharding.maybe_shard_with_name(state, state_mesh_shardings, config.shard_mode)
state, metrics = p_train_step(state, example_batch, nextrng)
@@ -466,7 +466,7 @@ def train_loop(config, recorder, state=None):
for eval_batch in eval_data_iterator:
if config.eval_steps > 0 and eval_step_count >= config.eval_steps:
break
with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
eval_metrics = p_eval_step(state, eval_batch, nextrng)
metric_logger.record_eval_metrics(step, metrics=eval_metrics)
max_logging.log(f"Completed eval step {eval_step_count}")
2 changes: 1 addition & 1 deletion src/MaxText/train_compile.py
@@ -120,7 +120,7 @@ def jit_and_compile(
logical_axis_rules,
):
"""Jit, lower, and compile func."""
with mesh, logical_axis_rules:
with jax.set_mesh(mesh), logical_axis_rules:
jitted = jax.jit(
func,
in_shardings=in_shardings,
2 changes: 1 addition & 1 deletion src/MaxText/train_utils.py
@@ -193,7 +193,7 @@ def setup_train_loop(config, recorder, devices=None):
)

# Apply reordering wrapper to data iterators if context parallelism is enabled
with mesh:
with jax.set_mesh(mesh):
if context_parallel_size > 1 and config.context_parallel_load_balance:
data_iterator = map(maxtext_utils.get_reorder_callable(context_parallel_size, config.shard_mode), data_iterator)
if eval_data_iterator:
105 changes: 48 additions & 57 deletions tests/maxtext_utils_test.py
@@ -318,52 +318,47 @@ def test_fully_sharded_2d(self):
"""
Tests that a 2D tensor fully sharded across both mesh axes passes the assertion.
"""
# Activate the mesh context.
with self.mesh:
# Define a sharding spec that shards the first tensor dimension by the 'fsdp' mesh axis
# and the second dimension by the 'tensor' mesh axis.
pspec = PartitionSpec("fsdp", "tensor")
# Create a parameter and apply the sharding, ensuring it's distributed across all devices.
params = {"layer1": jax.device_put(jnp.ones((8, 8)), NamedSharding(self.mesh, pspec))}

# Assert that the parameters are sufficiently sharded; this should pass with no error.
assert_params_sufficiently_sharded(params, self.mesh, tolerance=0.1)
# Define a sharding spec that shards the first tensor dimension by the 'fsdp' mesh axis
# and the second dimension by the 'tensor' mesh axis.
pspec = PartitionSpec("fsdp", "tensor")
# Create a parameter and apply the sharding, ensuring it's distributed across all devices.
params = {"layer1": jax.device_put(jnp.ones((8, 8)), NamedSharding(self.mesh, pspec))}

# Assert that the parameters are sufficiently sharded; this should pass with no error.
assert_params_sufficiently_sharded(params, self.mesh, tolerance=0.1)

def test_unsharded_fails(self):
"""
Tests that a completely unsharded (fully replicated) parameter fails the assertion.
"""
with self.mesh:
# Create a parameter without any sharding specification. It will be replicated on all devices.
params = {"layer1": jnp.ones((8, 8))}
# Create a parameter without any sharding specification. It will be replicated on all devices.
params = {"layer1": jnp.ones((8, 8))}

# Expect an AssertionError because 100% of params are unsharded, exceeding the 10% tolerance.
with self.assertRaises(AssertionError):
assert_params_sufficiently_sharded(params, self.mesh, tolerance=0.1)
# Expect an AssertionError because 100% of params are unsharded, exceeding the 10% tolerance.
with self.assertRaises(AssertionError):
assert_params_sufficiently_sharded(params, self.mesh, tolerance=0.1)

def test_mixed_sharding_fails(self):
"""
Tests that a mix of sharded and unsharded parameters fails when the unsharded
portion exceeds the tolerance.
"""
with self.mesh:
sharded_param = jax.device_put(jnp.ones((8, 8)), NamedSharding(self.mesh, PartitionSpec("fsdp", "tensor")))
unsharded_param = jnp.ones((8, 8))
params = {"layer1": sharded_param, "layer2": unsharded_param}
sharded_param = jax.device_put(jnp.ones((8, 8)), NamedSharding(self.mesh, PartitionSpec("fsdp", "tensor")))
unsharded_param = jnp.ones((8, 8))
params = {"layer1": sharded_param, "layer2": unsharded_param}

with self.assertRaises(AssertionError):
assert_params_sufficiently_sharded(params, self.mesh, tolerance=0.5)
with self.assertRaises(AssertionError):
assert_params_sufficiently_sharded(params, self.mesh, tolerance=0.5)

def test_3d_tensor_sharded_on_fsdp_axis(self):
"""
Tests that a 3D tensor sharded only on a valid target axis ('fsdp') should fail.
"""
with self.mesh:
pspec = PartitionSpec("fsdp", None, None)
params = {"conv3d_layer": jax.device_put(jnp.ones((8, 4, 4)), NamedSharding(self.mesh, pspec))}
pspec = PartitionSpec("fsdp", None, None)
params = {"conv3d_layer": jax.device_put(jnp.ones((8, 4, 4)), NamedSharding(self.mesh, pspec))}

with self.assertRaises(AssertionError):
assert_params_sufficiently_sharded(params, self.mesh, tolerance=0.2)
with self.assertRaises(AssertionError):
assert_params_sufficiently_sharded(params, self.mesh, tolerance=0.2)

def test_multi_axis_sharding_pass(self):
"""
@@ -374,13 +369,12 @@ def test_multi_axis_sharding_pass(self):
devices = np.array(jax.devices()).reshape((4, 1, 1, 1, 1))
mesh = Mesh(devices, self.mesh_axes)

with mesh:
# Shard across multiple axes, including the valid 'fsdp' axis.
pspec = PartitionSpec(("fsdp", "sequence"), "stage", ("tensor"), None)
params = {"complex_layer": jax.device_put(jnp.ones((8, 8, 2, 2)), NamedSharding(mesh, pspec))}
# Shard across multiple axes, including the valid 'fsdp' axis.
pspec = PartitionSpec(("fsdp", "sequence"), "stage", ("tensor"), None)
params = {"complex_layer": jax.device_put(jnp.ones((8, 8, 2, 2)), NamedSharding(mesh, pspec))}

# This should pass because 'fsdp' is a valid sharding axis being used.
assert_params_sufficiently_sharded(params, mesh, tolerance=0.05)
# This should pass because 'fsdp' is a valid sharding axis being used.
assert_params_sufficiently_sharded(params, mesh, tolerance=0.05)

def test_multi_axis_not_sharded_fails(self):
"""
@@ -389,30 +383,28 @@ def test_multi_axis_not_sharded_fails(self):
"""
devices = np.array(jax.devices()).reshape((4, 1, 1, 1, 1))
mesh = Mesh(devices, self.mesh_axes)
with mesh:
pspec = PartitionSpec(("sequence", "context"), "stage", "tensor", None)
params = {"complex_layer": jax.device_put(jnp.ones((8, 8, 2, 2)), NamedSharding(mesh, pspec))}
pspec = PartitionSpec(("sequence", "context"), "stage", "tensor", None)
params = {"complex_layer": jax.device_put(jnp.ones((8, 8, 2, 2)), NamedSharding(mesh, pspec))}

with self.assertRaises(AssertionError):
assert_params_sufficiently_sharded(params, mesh, tolerance=0.05)
with self.assertRaises(AssertionError):
assert_params_sufficiently_sharded(params, mesh, tolerance=0.05)

def test_multi_axis_mixed_sharding_fails(self):
"""
Tests that a mix of sharded (correctly) and unsharded tensors on a complex mesh fails.
"""
devices = np.array(jax.devices()).reshape((4, 1, 1, 1, 1))
mesh = Mesh(devices, self.mesh_axes)
with mesh:
sharded_pspec = PartitionSpec(("fsdp", "sequence"), "stage", ("tensor"), None)
sharded_param = jax.device_put(jnp.ones((8, 8, 2, 2)), NamedSharding(mesh, sharded_pspec))
unsharded_param = jnp.ones((8, 8, 2, 2))
params = {
"sharded_layer": sharded_param,
"unsharded_layer": unsharded_param,
}
sharded_pspec = PartitionSpec(("fsdp", "sequence"), "stage", ("tensor"), None)
sharded_param = jax.device_put(jnp.ones((8, 8, 2, 2)), NamedSharding(mesh, sharded_pspec))
unsharded_param = jnp.ones((8, 8, 2, 2))
params = {
"sharded_layer": sharded_param,
"unsharded_layer": unsharded_param,
}

with self.assertRaises(AssertionError):
assert_params_sufficiently_sharded(params, mesh, tolerance=0.5)
with self.assertRaises(AssertionError):
assert_params_sufficiently_sharded(params, mesh, tolerance=0.5)


class TestAssert_Formatted_sharding_annotations(unittest.TestCase):
@@ -435,15 +427,14 @@ def test_multi_axis_mixed_formating(self):
"""
Tests a mix of sharded and unsharded tensors on a complex mesh fails.
"""
with self.mesh:
sharded_pspec = PartitionSpec(("fsdp", "sequence"), "stage", ("tensor"), None)
sharded_param = jax.device_put(jnp.ones((8, 8, 2, 2)), NamedSharding(self.mesh, sharded_pspec))
unsharded_param = jnp.ones((8, 8, 2, 2))
params = {
"sharded_layer": sharded_param,
"unsharded_layer": unsharded_param,
}
self.assertIsNotNone(get_formatted_sharding_annotations(params, self.mesh))
sharded_pspec = PartitionSpec(("fsdp", "sequence"), "stage", ("tensor"), None)
sharded_param = jax.device_put(jnp.ones((8, 8, 2, 2)), NamedSharding(self.mesh, sharded_pspec))
unsharded_param = jnp.ones((8, 8, 2, 2))
params = {
"sharded_layer": sharded_param,
"unsharded_layer": unsharded_param,
}
self.assertIsNotNone(get_formatted_sharding_annotations(params, self.mesh))


class TestPromptLogprobsFromPrefill(unittest.TestCase):
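The test changes above drop the `with self.mesh:` wrappers because `jax.device_put` with an explicit `NamedSharding` already carries its mesh and needs no active mesh context. A minimal sketch of that point, with illustrative axis names rather than anything from this test suite:

```python
# Minimal sketch (not part of this diff): an explicit NamedSharding references
# its mesh directly, so no enclosing mesh context manager is required.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

devices = np.array(jax.devices()).reshape((jax.device_count(), 1))
mesh = Mesh(devices, ("fsdp", "tensor"))

# No `with mesh:` or jax.set_mesh(mesh) here.
param = jax.device_put(
    jnp.ones((jax.device_count() * 2, 8)),
    NamedSharding(mesh, PartitionSpec("fsdp", "tensor")),
)
print(param.sharding)
```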