Skip to content

Commit 2057019

Browse files
committed
make use_nccl_for_helix bool
1 parent 31325ee commit 2057019

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

tests/unittest/_torch/modules/test_mla_helix.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -838,20 +838,20 @@ def _run_single_rank(func, *args, **kwargs):
838838

839839
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="needs 2 GPUs to run this test")
840840
@pytest.mark.parametrize("scenario", test_scenarios, ids=lambda x: f"scenario: {x}")
841-
@pytest.mark.parametrize("use_nccl_for_helix", ["0", "1"], ids=["fifo", "nccl"])
841+
@pytest.mark.parametrize("use_nccl_for_helix", [False, True], ids=["fifo", "nccl"])
842842
def test_mla_helix_distributed(
843843
scenario: Scenario,
844-
use_nccl_for_helix: str,
844+
use_nccl_for_helix: bool,
845845
gen_steps: Optional[int] = None,
846846
max_mismatch_ratio: float = 0.02,
847847
mismatch_ratios: Optional[List[float]] = None,
848848
):
849849
# Set environment variable to control which codepath is used
850850
old_env_value = os.environ.get("TRTLLM_USE_NCCL_FOR_HELIX")
851-
os.environ["TRTLLM_USE_NCCL_FOR_HELIX"] = use_nccl_for_helix
851+
os.environ["TRTLLM_USE_NCCL_FOR_HELIX"] = "1" if use_nccl_for_helix else "0"
852852

853853
world_size = 2
854-
print(f"Testing with TRTLLM_USE_NCCL_FOR_HELIX={use_nccl_for_helix}.")
854+
print(f"Testing with TRTLLM_USE_NCCL_FOR_HELIX={'1' if use_nccl_for_helix else '0'}.")
855855
gen_steps = scenario.ref_steps if gen_steps is None else gen_steps
856856
try:
857857
with MPIPoolExecutor(max_workers=world_size) as executor:
@@ -873,10 +873,10 @@ def test_mla_helix_distributed(
873873

874874

875875
if __name__ == "__main__":
876-
for use_nccl in ["0", "1"]:
877-
nccl_mode = "NCCL" if use_nccl == "1" else "FIFO"
876+
for use_nccl in [False, True]:
877+
nccl_mode = "NCCL" if use_nccl else "FIFO"
878878
print(f"\n{'=' * 60}")
879-
print(f"Testing with TRTLLM_USE_NCCL_FOR_HELIX={use_nccl} ({nccl_mode} mode)")
879+
print(f"Testing with TRTLLM_USE_NCCL_FOR_HELIX={'1' if use_nccl else '0'} ({nccl_mode} mode)")
880880
print(f"{'=' * 60}\n")
881881
for scenario in all_scenarios[:11]:
882882
timing_steps = 256

0 commit comments

Comments
 (0)