
Commit 33cb190

update buffer size for EP sharding
1 parent 6191433 commit 33cb190

File tree

2 files changed: +64 -6 lines changed


src/MaxText/layers/moe.py

Lines changed: 1 addition & 6 deletions
@@ -1036,12 +1036,7 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r
       # This would result in num_expert_shards * input_size * experts_per_shard assignments. However, if
       # experts_per_shard > num_experts_per_tok we cannot assign more than num_experts_per_tok to all of the inputs.
       max_local_experts_per_tok = min(local_expert_size, self.config.num_experts_per_tok)
-      buffer_size = int(
-          num_expert_parallelism
-          * self.config.per_device_batch_size
-          * self.config.max_target_length
-          * max_local_experts_per_tok
-      )
+      buffer_size = int(num_expert_parallelism * batch_size * sequence_length * max_local_experts_per_tok)
       output_shape = jnp.zeros((buffer_size, self.config.emb_dim), dtype=x.dtype)
 
       x = jax.lax.ragged_all_to_all(
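
The substance of the change: the ragged_all_to_all output buffer is now sized from the shard-local activation shape (batch_size, sequence_length) instead of the global config values (per_device_batch_size, max_target_length). A plausible reading, sketched below with hypothetical numbers not taken from this commit: once another mesh axis such as context parallelism also shards the inputs, the local sequence length is only a fraction of max_target_length, so the config-derived size no longer matches the array the buffer has to hold.

# Minimal arithmetic sketch (hypothetical values, not from this commit) showing how
# the config-derived buffer size and the local-shape-derived size diverge once
# context parallelism also shards the sequence dimension.
per_device_batch_size = 4
max_target_length = 1024
ici_context_parallelism = 2        # sequence axis split across 2 shards
num_expert_parallelism = 2
max_local_experts_per_tok = 2

# Old: sized from the global config shapes.
old_buffer_size = int(
    num_expert_parallelism * per_device_batch_size * max_target_length * max_local_experts_per_tok
)

# New: sized from the local activation shape actually seen on this shard,
# i.e. x.shape == (batch_size, sequence_length, emb_dim) after sharding.
batch_size = per_device_batch_size                 # batch axis unsplit in this example
sequence_length = max_target_length // ici_context_parallelism
new_buffer_size = int(
    num_expert_parallelism * batch_size * sequence_length * max_local_experts_per_tok
)

print(old_buffer_size, new_buffer_size)  # 16384 8192: the old formula allocates 2x the local need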

tests/moe_test.py

Lines changed: 63 additions & 0 deletions
@@ -653,6 +653,69 @@ def test_megablox_context_parallelism(self):
       actual_output, _ = self.get_moe_output(variables, hidden_states, cfg, mesh)
       self.assertTrue(jax.numpy.allclose(expected_output, actual_output, rtol=1e-02, atol=1e-02, equal_nan=False))
 
+  @pytest.mark.tpu_only
+  def test_megablox_expert_context_parallelism(self):
+    cfg = pyconfig.initialize(
+        [None, os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")],
+        run_name="moe_block_megablox_ep_cp_test",
+        enable_checkpointing=False,
+        model_name="mixtral-8x7b",
+        dtype="bfloat16",
+        megablox=True,
+        sparse_matmul=True,
+        per_device_batch_size=4,
+        ici_context_parallelism=2,
+        ici_expert_parallelism=2,
+        packing=False,
+    )
+
+    rng = jax.random.PRNGKey(2345)
+    rng_model, rng_hidden_states = jax.random.split(rng)
+    device_count = jax.device_count()
+    hidden_states = jax.random.uniform(
+        rng_hidden_states,
+        (int(cfg.per_device_batch_size) * device_count, cfg.max_target_length, cfg.base_emb_dim),
+        dtype=cfg.dtype,
+    )
+
+    devices_array = maxtext_utils.create_device_mesh(cfg)
+    mesh = Mesh(devices_array, cfg.mesh_axes)
+    with nn_partitioning.axis_rules(cfg.logical_axis_rules):
+      variables, expected_output = self.get_expected_output(rng_model, hidden_states, cfg, mesh)
+      actual_output, _ = self.get_moe_output(variables, hidden_states, cfg, mesh)
+      self.assertTrue(jax.numpy.allclose(expected_output, actual_output, rtol=1e-02, atol=1e-02, equal_nan=False))
+
+  @pytest.mark.tpu_only
+  def test_megablox_expert_tensor_parallelism(self):
+    cfg = pyconfig.initialize(
+        [None, os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")],
+        run_name="moe_block_megablox_ep_tp_test",
+        enable_checkpointing=False,
+        model_name="mixtral-8x7b",
+        dtype="bfloat16",
+        megablox=True,
+        sparse_matmul=True,
+        per_device_batch_size=4,
+        ici_tensor_parallelism=2,
+        ici_expert_parallelism=2,
+    )
+
+    rng = jax.random.PRNGKey(2345)
+    rng_model, rng_hidden_states = jax.random.split(rng)
+    device_count = jax.device_count()
+    hidden_states = jax.random.uniform(
+        rng_hidden_states,
+        (int(cfg.per_device_batch_size) * device_count, cfg.max_target_length, cfg.base_emb_dim),
+        dtype=cfg.dtype,
+    )
+
+    devices_array = maxtext_utils.create_device_mesh(cfg)
+    mesh = Mesh(devices_array, cfg.mesh_axes)
+    with nn_partitioning.axis_rules(cfg.logical_axis_rules):
+      variables, expected_output = self.get_expected_output(rng_model, hidden_states, cfg, mesh)
+      actual_output, _ = self.get_moe_output(variables, hidden_states, cfg, mesh)
+      self.assertTrue(jax.numpy.allclose(expected_output, actual_output, rtol=1e-02, atol=1e-02, equal_nan=False))
+
   def test_random_routing(self):
     bs, seq_len, num_experts, num_experts_per_tok = 12, 1024, 8, 2
     rng = jax.random.PRNGKey(0)
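
A side note on the two new tests: they are identical except for the run name and the parallelism flags passed to pyconfig.initialize. A possible consolidation, sketched below as a hypothetical helper that is not part of this commit (every call in it is taken from the tests themselves), would keep further parallelism combinations cheap to add:

# Hypothetical helper (not in the repo): both new tests reduce to this pattern,
# differing only in run_name and the parallelism keyword arguments.
def _run_megablox_parallelism_case(self, run_name, **parallelism_kwargs):
  cfg = pyconfig.initialize(
      [None, os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")],
      run_name=run_name,
      enable_checkpointing=False,
      model_name="mixtral-8x7b",
      dtype="bfloat16",
      megablox=True,
      sparse_matmul=True,
      per_device_batch_size=4,
      **parallelism_kwargs,
  )
  rng_model, rng_hidden_states = jax.random.split(jax.random.PRNGKey(2345))
  hidden_states = jax.random.uniform(
      rng_hidden_states,
      (int(cfg.per_device_batch_size) * jax.device_count(), cfg.max_target_length, cfg.base_emb_dim),
      dtype=cfg.dtype,
  )
  mesh = Mesh(maxtext_utils.create_device_mesh(cfg), cfg.mesh_axes)
  with nn_partitioning.axis_rules(cfg.logical_axis_rules):
    variables, expected_output = self.get_expected_output(rng_model, hidden_states, cfg, mesh)
    actual_output, _ = self.get_moe_output(variables, hidden_states, cfg, mesh)
    self.assertTrue(jax.numpy.allclose(expected_output, actual_output, rtol=1e-02, atol=1e-02, equal_nan=False))

# Usage matching the two added tests:
#   self._run_megablox_parallelism_case(
#       "moe_block_megablox_ep_cp_test", ici_context_parallelism=2, ici_expert_parallelism=2, packing=False)
#   self._run_megablox_parallelism_case(
#       "moe_block_megablox_ep_tp_test", ici_tensor_parallelism=2, ici_expert_parallelism=2)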
