@@ -99,7 +99,9 @@ def max_position_embeddings(self) -> int:
 
 
 all_scenarios = [
-    Scenario(batch=1, ctx_len=64),
+    # Scenario(batch=1, ctx_len=64),
+    # Scenario(batch=1, ctx_len=64),
+    Scenario(batch=1, ctx_len=128),
     Scenario(batch=1, ctx_len=512),
     Scenario(batch=1, ctx_len=1024),
     Scenario(batch=1, ctx_len=2048),
@@ -414,7 +416,7 @@ def rotate_half_inv(x):
     )
 
     mapping = Mapping(
-        world_size=world_size, rank=rank, cp_size=world_size, cp_config={"cp_type": CpType.HELIX}
+        world_size=world_size, rank=rank, cp_size=world_size, cp_config={"cp_type": CpType.HELIX, "tokens_per_block": 32}
     )
     # use cp_allgather here to broadcast from rank 0 to all other ranks
     ret_all = cp_allgather(ret, mapping=mapping, dim=0)
@@ -837,15 +839,15 @@ def _run_single_rank(func, *args, **kwargs):
         raise Exception(f"\n\nError occurred. Original traceback is\n{tb}\n")
 
 
-@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="needs 2 GPUs to run this test")
+@pytest.mark.skipif(torch.cuda.device_count() < 8, reason="needs 8 GPUs to run this test")
 @pytest.mark.parametrize("scenario", test_scenarios, ids=lambda x: f"scenario: {x}")
 def test_mla_helix_distributed(
     scenario: Scenario,
     gen_steps: Optional[int] = None,
     max_mismatch_ratio: float = 0.02,
     mismatch_ratios: Optional[List[float]] = None,
 ):
-    world_size = 2
+    world_size = 8
     gen_steps = scenario.ref_steps if gen_steps is None else gen_steps
     with MPIPoolExecutor(max_workers=world_size) as executor:
         results = executor.map(
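
The hunk above widens the test from 2 to 8 ranks, driven through mpi4py's MPIPoolExecutor. The snippet below is a minimal, self-contained sketch of that fan-out pattern, not part of the commit: _per_rank_work is a hypothetical stand-in for the real _run_single_rank body, and only MPIPoolExecutor, its map call, and the world_size of 8 are taken from the diff.

# Minimal sketch (assumption: illustrative only, not code from this commit)
# of the one-worker-per-rank pattern used by test_mla_helix_distributed.
from mpi4py.futures import MPIPoolExecutor


def _per_rank_work(rank: int, world_size: int) -> int:
    # Placeholder for the real per-rank body; the actual test runs the
    # Helix MLA forward pass on the GPU owned by this rank.
    return rank


if __name__ == "__main__":
    world_size = 8  # mirrors the 8-GPU requirement introduced by this commit
    with MPIPoolExecutor(max_workers=world_size) as executor:
        # map() fans the per-rank function out across world_size workers,
        # passing each worker its rank index and the shared world_size.
        results = list(
            executor.map(_per_rank_work, range(world_size), [world_size] * world_size)
        )
    assert results == list(range(world_size))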