Commit 115bc25

unit test works well for pure CP on 4 and 8 GPUs

1 parent a328124 · commit 115bc25

1 file changed: +6 -4 lines changed

tests/unittest/_torch/modules/test_mla_helix.py

Lines changed: 6 additions & 4 deletions
@@ -99,7 +99,9 @@ def max_position_embeddings(self) -> int:
 
 
 all_scenarios = [
-    Scenario(batch=1, ctx_len=64),
+    # Scenario(batch=1, ctx_len=64),
+    # Scenario(batch=1, ctx_len=64),
+    Scenario(batch=1, ctx_len=128),
     Scenario(batch=1, ctx_len=512),
     Scenario(batch=1, ctx_len=1024),
     Scenario(batch=1, ctx_len=2048),
@@ -414,7 +416,7 @@ def rotate_half_inv(x):
     )
 
     mapping = Mapping(
-        world_size=world_size, rank=rank, cp_size=world_size, cp_config={"cp_type": CpType.HELIX}
+        world_size=world_size, rank=rank, cp_size=world_size, cp_config={"cp_type": CpType.HELIX, "tokens_per_block": 32}
     )
     # use cp_allgather here to broadcast from rank 0 to all other ranks
     ret_all = cp_allgather(ret, mapping=mapping, dim=0)
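
The functional change in this hunk is adding "tokens_per_block": 32 to the helix cp_config. As a rough illustration, a per-rank mapping for pure context parallelism could be built like this — a sketch assuming the tensorrt_llm Mapping/CpType imports this test already uses, and assuming 32 is chosen to match the KV-cache block size so helix CP can split sequences on block boundaries:

# Sketch only: pure CP mapping (cp_size == world_size, no TP/PP).
from tensorrt_llm.mapping import CpType, Mapping

def make_helix_mapping(world_size: int, rank: int) -> Mapping:
    return Mapping(
        world_size=world_size,
        rank=rank,
        cp_size=world_size,
        # tokens_per_block is assumed here to mirror the KV-cache block size.
        cp_config={"cp_type": CpType.HELIX, "tokens_per_block": 32},
    )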
@@ -837,15 +839,15 @@ def _run_single_rank(func, *args, **kwargs):
         raise Exception(f"\n\nError occurred. Original traceback is\n{tb}\n")
 
 
-@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="needs 2 GPUs to run this test")
+@pytest.mark.skipif(torch.cuda.device_count() < 8, reason="needs 8 GPUs to run this test")
 @pytest.mark.parametrize("scenario", test_scenarios, ids=lambda x: f"scenario: {x}")
 def test_mla_helix_distributed(
     scenario: Scenario,
     gen_steps: Optional[int] = None,
     max_mismatch_ratio: float = 0.02,
     mismatch_ratios: Optional[List[float]] = None,
 ):
-    world_size = 2
+    world_size = 8
     gen_steps = scenario.ref_steps if gen_steps is None else gen_steps
     with MPIPoolExecutor(max_workers=world_size) as executor:
         results = executor.map(
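
The skipif threshold and world_size move together from 2 to 8, matching the commit message's "pure CP on 4 and 8 GPUs": the MPI pool spawns exactly one worker per participating GPU. A minimal sketch of that fan-out pattern, assuming mpi4py's MPIPoolExecutor (which this test uses) and a hypothetical per-rank entry point:

# Sketch only: run func once per rank and surface any per-rank failure.
from mpi4py.futures import MPIPoolExecutor

def run_on_all_ranks(func, world_size: int, *args):
    with MPIPoolExecutor(max_workers=world_size) as executor:
        # Repeat each positional arg once per rank; map zips them back up.
        results = executor.map(func, *([arg] * world_size for arg in args))
        for rank, ok in enumerate(results):
            assert ok, f"rank {rank} failed"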
