@@ -653,6 +653,69 @@ def test_megablox_context_parallelism(self):
       actual_output, _ = self.get_moe_output(variables, hidden_states, cfg, mesh)
       self.assertTrue(jax.numpy.allclose(expected_output, actual_output, rtol=1e-02, atol=1e-02, equal_nan=False))
 
+  @pytest.mark.tpu_only
+  def test_megablox_expert_context_parallelism(self):
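+    """Megablox MoE output should match the reference output under 2-way expert x 2-way context parallelism."""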
+    cfg = pyconfig.initialize(
+        [None, os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")],
+        run_name="moe_block_megablox_ep_cp_test",
+        enable_checkpointing=False,
+        model_name="mixtral-8x7b",
+        dtype="bfloat16",
+        megablox=True,
+        sparse_matmul=True,
+        per_device_batch_size=4,
+        ici_context_parallelism=2,
+        ici_expert_parallelism=2,
+        packing=False,
+    )
+
+    rng = jax.random.PRNGKey(2345)
+    rng_model, rng_hidden_states = jax.random.split(rng)
+    device_count = jax.device_count()
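+    # Input shape is (global batch, sequence length, embed dim), with global batch = per-device batch * device count.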
+    hidden_states = jax.random.uniform(
+        rng_hidden_states,
+        (int(cfg.per_device_batch_size) * device_count, cfg.max_target_length, cfg.base_emb_dim),
+        dtype=cfg.dtype,
+    )
+
+    devices_array = maxtext_utils.create_device_mesh(cfg)
+    mesh = Mesh(devices_array, cfg.mesh_axes)
+    with nn_partitioning.axis_rules(cfg.logical_axis_rules):
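+      # Both paths run under the same logical axis rules; outputs should agree within a loose bfloat16 tolerance.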
+      variables, expected_output = self.get_expected_output(rng_model, hidden_states, cfg, mesh)
+      actual_output, _ = self.get_moe_output(variables, hidden_states, cfg, mesh)
+      self.assertTrue(jax.numpy.allclose(expected_output, actual_output, rtol=1e-02, atol=1e-02, equal_nan=False))
+
+  @pytest.mark.tpu_only
+  def test_megablox_expert_tensor_parallelism(self):
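+    """Megablox MoE output should match the reference output under 2-way expert x 2-way tensor parallelism."""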
+    cfg = pyconfig.initialize(
+        [None, os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")],
+        run_name="moe_block_megablox_ep_tp_test",
+        enable_checkpointing=False,
+        model_name="mixtral-8x7b",
+        dtype="bfloat16",
+        megablox=True,
+        sparse_matmul=True,
+        per_device_batch_size=4,
+        ici_tensor_parallelism=2,
+        ici_expert_parallelism=2,
+    )
+
+    rng = jax.random.PRNGKey(2345)
+    rng_model, rng_hidden_states = jax.random.split(rng)
+    device_count = jax.device_count()
+    hidden_states = jax.random.uniform(
+        rng_hidden_states,
+        (int(cfg.per_device_batch_size) * device_count, cfg.max_target_length, cfg.base_emb_dim),
+        dtype=cfg.dtype,
+    )
+
+    devices_array = maxtext_utils.create_device_mesh(cfg)
+    mesh = Mesh(devices_array, cfg.mesh_axes)
+    with nn_partitioning.axis_rules(cfg.logical_axis_rules):
+      variables, expected_output = self.get_expected_output(rng_model, hidden_states, cfg, mesh)
+      actual_output, _ = self.get_moe_output(variables, hidden_states, cfg, mesh)
+      self.assertTrue(jax.numpy.allclose(expected_output, actual_output, rtol=1e-02, atol=1e-02, equal_nan=False))
+
   def test_random_routing(self):
     bs, seq_len, num_experts, num_experts_per_tok = 12, 1024, 8, 2
     rng = jax.random.PRNGKey(0)