From 722d7e258e30ac54d78ac181604d7b9e2ace92c2 Mon Sep 17 00:00:00 2001 From: NuojCheng Date: Wed, 10 Dec 2025 03:18:37 +0000 Subject: [PATCH] update with mesh --- src/MaxText/elastic_train.py | 2 +- src/MaxText/experimental/rl/grpo_trainer.py | 4 +- src/MaxText/maxtext_utils.py | 6 +- src/MaxText/model_creation_utils.py | 2 +- src/MaxText/rl/train_rl.py | 2 +- src/MaxText/sft/sft_trainer.py | 2 +- src/MaxText/sft_trainer.py | 6 +- src/MaxText/train.py | 6 +- src/MaxText/train_compile.py | 2 +- src/MaxText/train_utils.py | 2 +- tests/maxtext_utils_test.py | 105 +++++++++----------- 11 files changed, 65 insertions(+), 74 deletions(-) diff --git a/src/MaxText/elastic_train.py b/src/MaxText/elastic_train.py index 15be83393..e2ee2ec95 100644 --- a/src/MaxText/elastic_train.py +++ b/src/MaxText/elastic_train.py @@ -186,7 +186,7 @@ def train_loop(config, elastic_manager, recorder, state=None): ) = setup_train_loop(config, recorder) p_train_step, _ = train_utils.jit_train_and_eval_step(config, model, mesh, state, state_mesh_shardings, train_step) - with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): + with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules): shaped_batch = maxtext_utils.get_shaped_batch(config) compiled = p_train_step.lower(state, shaped_batch, init_rng).compile() compiled_stats = compiled.memory_analysis() diff --git a/src/MaxText/experimental/rl/grpo_trainer.py b/src/MaxText/experimental/rl/grpo_trainer.py index 440cd190d..926f1e1b6 100644 --- a/src/MaxText/experimental/rl/grpo_trainer.py +++ b/src/MaxText/experimental/rl/grpo_trainer.py @@ -824,7 +824,7 @@ def generation_worker_fn( continue train_rng, rng = random.split(init_rng) example_batch = jax.device_put(example_batch, data_sharding) - with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): + with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules): state, metrics = p_train_step(state, example_batch, train_rng) with jax.profiler.StepTraceAnnotation("transfer data", step_num=step): if step != 0 and step % config.inference_rollouts == 0: @@ -862,7 +862,7 @@ def generation_worker_fn( for eval_batch in eval_data_iterator: if 0 < config.eval_steps <= eval_step_count: break - with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): + with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules): eval_metrics = p_eval_step(state, eval_batch, rng) metric_logger.record_eval_metrics(step, metrics=eval_metrics) max_logging.log(f"Completed eval step {eval_step_count}") diff --git a/src/MaxText/maxtext_utils.py b/src/MaxText/maxtext_utils.py index 30abfbdf9..0929ec775 100644 --- a/src/MaxText/maxtext_utils.py +++ b/src/MaxText/maxtext_utils.py @@ -933,7 +933,7 @@ def move(path, x): unboxed_abstract_sharded_state = max_utils.unbox_logicallypartioned(abstract_sharded_state) # Initialization - with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): + with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules): state_mesh_annotations = nn.logical_to_mesh(state_logical_annotations) return ( unboxed_abstract_sharded_state, @@ -969,7 +969,7 @@ def init_kv_cache(model, config): init_kv_cache_partial = functools.partial(init_kv_cache, model, config) abstract_state = jax.eval_shape(init_kv_cache_partial) state_logical_annotations = nn.get_partition_spec(abstract_state) - with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): + with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules): state_mesh_annotations = nn.logical_to_mesh(state_logical_annotations) return state_mesh_annotations @@ -998,7 +998,7 @@ def init_kv_cache(model, config): init_kv_cache_partial = functools.partial(init_kv_cache, model, config) abstract_state = jax.eval_shape(init_kv_cache_partial) state_logical_annotations = nn.get_partition_spec(abstract_state) - with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): + with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules): state_mesh_annotations = nn.logical_to_mesh(state_logical_annotations) return state_mesh_annotations diff --git a/src/MaxText/model_creation_utils.py b/src/MaxText/model_creation_utils.py index 82211d315..ade242d43 100644 --- a/src/MaxText/model_creation_utils.py +++ b/src/MaxText/model_creation_utils.py @@ -154,7 +154,7 @@ def create_sharded_state(): model = _create_model_partial() return nnx.state(model) - with mesh: + with jax.set_mesh(mesh): # Create the model with sharded parameters. with nn.logical_axis_rules(config.logical_axis_rules): sharded_state = create_sharded_state() diff --git a/src/MaxText/rl/train_rl.py b/src/MaxText/rl/train_rl.py index cbf185af1..6d4e79e34 100644 --- a/src/MaxText/rl/train_rl.py +++ b/src/MaxText/rl/train_rl.py @@ -92,7 +92,7 @@ def get_maxtext_model(config, devices=None): # load_parameters_path=/path/to/your/output/directory/0/items """ model, mesh = model_creation_utils.create_nnx_model(config, devices=devices) - with mesh: + with jax.set_mesh(mesh): tunix_model = TunixMaxTextAdapter(base_model=model) tunix_model.config = None return tunix_model, mesh diff --git a/src/MaxText/sft/sft_trainer.py b/src/MaxText/sft/sft_trainer.py index 3a5ecc4ca..6e5518744 100644 --- a/src/MaxText/sft/sft_trainer.py +++ b/src/MaxText/sft/sft_trainer.py @@ -160,7 +160,7 @@ def setup_trainer_state(mt_config, goodput_recorder=None): def train_model(mt_config, trainer, mesh): """Runs the SFT training loop in Tunix.""" - with mesh, nn_partitioning.axis_rules(mt_config.logical_axis_rules): + with jax.set_mesh(mesh), nn_partitioning.axis_rules(mt_config.logical_axis_rules): trainer.train(trainer.data_hooks.train_data_iterator, trainer.data_hooks.eval_data_iterator) return trainer diff --git a/src/MaxText/sft_trainer.py b/src/MaxText/sft_trainer.py index c428e1402..bc60b32b1 100644 --- a/src/MaxText/sft_trainer.py +++ b/src/MaxText/sft_trainer.py @@ -75,7 +75,7 @@ def train_loop(config, recorder, state=None): config, model, mesh, state, state_mesh_shardings, train_step, eval_step, eval_data_iterator ) - with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): + with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules): shaped_batch = maxtext_utils.get_shaped_batch(config) compiled = p_train_step.lower(state, shaped_batch, init_rng).compile() compiled_stats = compiled.memory_analysis() @@ -99,7 +99,7 @@ def train_loop(config, recorder, state=None): # pylint: disable=not-callable nextrng = jax.jit(jax.random.fold_in)(init_rng, step) with maybe_record_goodput(recorder, GoodputEvent.STEP, step): - with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): + with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules): state, metrics = p_train_step(state, example_batch, nextrng) step_time_delta = datetime.datetime.now() - last_step_completion @@ -124,7 +124,7 @@ def train_loop(config, recorder, state=None): for eval_batch in eval_data_iterator: if config.eval_steps > 0 and eval_step_count >= config.eval_steps: break - with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): + with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules): eval_metrics = p_eval_step(state, eval_batch, nextrng) metric_logger.record_eval_metrics(step, metrics=eval_metrics) max_logging.log(f"Completed eval step {eval_step_count}") diff --git a/src/MaxText/train.py b/src/MaxText/train.py index fb5fcebdb..f5e8cf377 100644 --- a/src/MaxText/train.py +++ b/src/MaxText/train.py @@ -402,7 +402,7 @@ def train_loop(config, recorder, state=None): params_shardings, ) - with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): + with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules): shaped_batch = maxtext_utils.get_shaped_batch(config) if config.shard_optimizer_over_data: state = sharding.maybe_shard_with_name(state, state_mesh_shardings, config.shard_mode) @@ -434,7 +434,7 @@ def train_loop(config, recorder, state=None): # pylint: disable=not-callable nextrng = jax.jit(jax.random.fold_in)(init_rng, step) with maybe_record_goodput(recorder, GoodputEvent.STEP, step): - with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): + with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules): if config.shard_optimizer_over_data: state = sharding.maybe_shard_with_name(state, state_mesh_shardings, config.shard_mode) state, metrics = p_train_step(state, example_batch, nextrng) @@ -466,7 +466,7 @@ def train_loop(config, recorder, state=None): for eval_batch in eval_data_iterator: if config.eval_steps > 0 and eval_step_count >= config.eval_steps: break - with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): + with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules): eval_metrics = p_eval_step(state, eval_batch, nextrng) metric_logger.record_eval_metrics(step, metrics=eval_metrics) max_logging.log(f"Completed eval step {eval_step_count}") diff --git a/src/MaxText/train_compile.py b/src/MaxText/train_compile.py index 4b0d5e88d..0ce819c35 100644 --- a/src/MaxText/train_compile.py +++ b/src/MaxText/train_compile.py @@ -120,7 +120,7 @@ def jit_and_compile( logical_axis_rules, ): """Jit, lower, and compile func.""" - with mesh, logical_axis_rules: + with jax.set_mesh(mesh), logical_axis_rules: jitted = jax.jit( func, in_shardings=in_shardings, diff --git a/src/MaxText/train_utils.py b/src/MaxText/train_utils.py index ce465a56e..edb0ac0f5 100644 --- a/src/MaxText/train_utils.py +++ b/src/MaxText/train_utils.py @@ -193,7 +193,7 @@ def setup_train_loop(config, recorder, devices=None): ) # Apply reordering wrapper to data iterators if context parallelism is enabled - with mesh: + with jax.set_mesh(mesh): if context_parallel_size > 1 and config.context_parallel_load_balance: data_iterator = map(maxtext_utils.get_reorder_callable(context_parallel_size, config.shard_mode), data_iterator) if eval_data_iterator: diff --git a/tests/maxtext_utils_test.py b/tests/maxtext_utils_test.py index e9918b6ff..30a0032f8 100644 --- a/tests/maxtext_utils_test.py +++ b/tests/maxtext_utils_test.py @@ -318,52 +318,47 @@ def test_fully_sharded_2d(self): """ Tests that a 2D tensor fully sharded across both mesh axes passes the assertion. """ - # Activate the mesh context. - with self.mesh: - # Define a sharding spec that shards the first tensor dimension by the 'fsdp' mesh axis - # and the second dimension by the 'tensor' mesh axis. - pspec = PartitionSpec("fsdp", "tensor") - # Create a parameter and apply the sharding, ensuring it's distributed across all devices. - params = {"layer1": jax.device_put(jnp.ones((8, 8)), NamedSharding(self.mesh, pspec))} - - # Assert that the parameters are sufficiently sharded; this should pass with no error. - assert_params_sufficiently_sharded(params, self.mesh, tolerance=0.1) + # Define a sharding spec that shards the first tensor dimension by the 'fsdp' mesh axis + # and the second dimension by the 'tensor' mesh axis. + pspec = PartitionSpec("fsdp", "tensor") + # Create a parameter and apply the sharding, ensuring it's distributed across all devices. + params = {"layer1": jax.device_put(jnp.ones((8, 8)), NamedSharding(self.mesh, pspec))} + + # Assert that the parameters are sufficiently sharded; this should pass with no error. + assert_params_sufficiently_sharded(params, self.mesh, tolerance=0.1) def test_unsharded_fails(self): """ Tests that a completely unsharded (fully replicated) parameter fails the assertion. """ - with self.mesh: - # Create a parameter without any sharding specification. It will be replicated on all devices. - params = {"layer1": jnp.ones((8, 8))} + # Create a parameter without any sharding specification. It will be replicated on all devices. + params = {"layer1": jnp.ones((8, 8))} - # Expect an AssertionError because 100% of params are unsharded, exceeding the 10% tolerance. - with self.assertRaises(AssertionError): - assert_params_sufficiently_sharded(params, self.mesh, tolerance=0.1) + # Expect an AssertionError because 100% of params are unsharded, exceeding the 10% tolerance. + with self.assertRaises(AssertionError): + assert_params_sufficiently_sharded(params, self.mesh, tolerance=0.1) def test_mixed_sharding_fails(self): """ Tests that a mix of sharded and unsharded parameters fails when the unsharded portion exceeds the tolerance. """ - with self.mesh: - sharded_param = jax.device_put(jnp.ones((8, 8)), NamedSharding(self.mesh, PartitionSpec("fsdp", "tensor"))) - unsharded_param = jnp.ones((8, 8)) - params = {"layer1": sharded_param, "layer2": unsharded_param} + sharded_param = jax.device_put(jnp.ones((8, 8)), NamedSharding(self.mesh, PartitionSpec("fsdp", "tensor"))) + unsharded_param = jnp.ones((8, 8)) + params = {"layer1": sharded_param, "layer2": unsharded_param} - with self.assertRaises(AssertionError): - assert_params_sufficiently_sharded(params, self.mesh, tolerance=0.5) + with self.assertRaises(AssertionError): + assert_params_sufficiently_sharded(params, self.mesh, tolerance=0.5) def test_3d_tensor_sharded_on_fsdp_axis(self): """ Tests that a 3D tensor sharded only on a valid target axis ('fsdp') should fail. """ - with self.mesh: - pspec = PartitionSpec("fsdp", None, None) - params = {"conv3d_layer": jax.device_put(jnp.ones((8, 4, 4)), NamedSharding(self.mesh, pspec))} + pspec = PartitionSpec("fsdp", None, None) + params = {"conv3d_layer": jax.device_put(jnp.ones((8, 4, 4)), NamedSharding(self.mesh, pspec))} - with self.assertRaises(AssertionError): - assert_params_sufficiently_sharded(params, self.mesh, tolerance=0.2) + with self.assertRaises(AssertionError): + assert_params_sufficiently_sharded(params, self.mesh, tolerance=0.2) def test_multi_axis_sharding_pass(self): """ @@ -374,13 +369,12 @@ def test_multi_axis_sharding_pass(self): devices = np.array(jax.devices()).reshape((4, 1, 1, 1, 1)) mesh = Mesh(devices, self.mesh_axes) - with mesh: - # Shard across multiple axes, including the valid 'fsdp' axis. - pspec = PartitionSpec(("fsdp", "sequence"), "stage", ("tensor"), None) - params = {"complex_layer": jax.device_put(jnp.ones((8, 8, 2, 2)), NamedSharding(mesh, pspec))} + # Shard across multiple axes, including the valid 'fsdp' axis. + pspec = PartitionSpec(("fsdp", "sequence"), "stage", ("tensor"), None) + params = {"complex_layer": jax.device_put(jnp.ones((8, 8, 2, 2)), NamedSharding(mesh, pspec))} - # This should pass because 'fsdp' is a valid sharding axis being used. - assert_params_sufficiently_sharded(params, mesh, tolerance=0.05) + # This should pass because 'fsdp' is a valid sharding axis being used. + assert_params_sufficiently_sharded(params, mesh, tolerance=0.05) def test_multi_axis_not_sharded_fails(self): """ @@ -389,12 +383,11 @@ def test_multi_axis_not_sharded_fails(self): """ devices = np.array(jax.devices()).reshape((4, 1, 1, 1, 1)) mesh = Mesh(devices, self.mesh_axes) - with mesh: - pspec = PartitionSpec(("sequence", "context"), "stage", "tensor", None) - params = {"complex_layer": jax.device_put(jnp.ones((8, 8, 2, 2)), NamedSharding(mesh, pspec))} + pspec = PartitionSpec(("sequence", "context"), "stage", "tensor", None) + params = {"complex_layer": jax.device_put(jnp.ones((8, 8, 2, 2)), NamedSharding(mesh, pspec))} - with self.assertRaises(AssertionError): - assert_params_sufficiently_sharded(params, mesh, tolerance=0.05) + with self.assertRaises(AssertionError): + assert_params_sufficiently_sharded(params, mesh, tolerance=0.05) def test_multi_axis_mixed_sharding_fails(self): """ @@ -402,17 +395,16 @@ def test_multi_axis_mixed_sharding_fails(self): """ devices = np.array(jax.devices()).reshape((4, 1, 1, 1, 1)) mesh = Mesh(devices, self.mesh_axes) - with mesh: - sharded_pspec = PartitionSpec(("fsdp", "sequence"), "stage", ("tensor"), None) - sharded_param = jax.device_put(jnp.ones((8, 8, 2, 2)), NamedSharding(mesh, sharded_pspec)) - unsharded_param = jnp.ones((8, 8, 2, 2)) - params = { - "sharded_layer": sharded_param, - "unsharded_layer": unsharded_param, - } + sharded_pspec = PartitionSpec(("fsdp", "sequence"), "stage", ("tensor"), None) + sharded_param = jax.device_put(jnp.ones((8, 8, 2, 2)), NamedSharding(mesh, sharded_pspec)) + unsharded_param = jnp.ones((8, 8, 2, 2)) + params = { + "sharded_layer": sharded_param, + "unsharded_layer": unsharded_param, + } - with self.assertRaises(AssertionError): - assert_params_sufficiently_sharded(params, mesh, tolerance=0.5) + with self.assertRaises(AssertionError): + assert_params_sufficiently_sharded(params, mesh, tolerance=0.5) class TestAssert_Formatted_sharding_annotations(unittest.TestCase): @@ -435,15 +427,14 @@ def test_multi_axis_mixed_formating(self): """ Tests a mix of sharded and unsharded tensors on a complex mesh fails. """ - with self.mesh: - sharded_pspec = PartitionSpec(("fsdp", "sequence"), "stage", ("tensor"), None) - sharded_param = jax.device_put(jnp.ones((8, 8, 2, 2)), NamedSharding(self.mesh, sharded_pspec)) - unsharded_param = jnp.ones((8, 8, 2, 2)) - params = { - "sharded_layer": sharded_param, - "unsharded_layer": unsharded_param, - } - self.assertIsNotNone(get_formatted_sharding_annotations(params, self.mesh)) + sharded_pspec = PartitionSpec(("fsdp", "sequence"), "stage", ("tensor"), None) + sharded_param = jax.device_put(jnp.ones((8, 8, 2, 2)), NamedSharding(self.mesh, sharded_pspec)) + unsharded_param = jnp.ones((8, 8, 2, 2)) + params = { + "sharded_layer": sharded_param, + "unsharded_layer": unsharded_param, + } + self.assertIsNotNone(get_formatted_sharding_annotations(params, self.mesh)) class TestPromptLogprobsFromPrefill(unittest.TestCase):