diff --git a/tests/integration/test_lists/waives.txt b/tests/integration/test_lists/waives.txt
index 6c0b5aedd27..e1e580ab9fa 100644
--- a/tests/integration/test_lists/waives.txt
+++ b/tests/integration/test_lists/waives.txt
@@ -491,8 +491,6 @@ triton_server/test_triton.py::test_llava_onevision[llava_onevision] SKIP (https:
 triton_server/test_triton.py::test_gpt_ib_lad[gpt-ib-lad] SKIP (https://nvbugs/5775223)
 accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[dep4_latency_moe_trtllm-torch_compile=True] SKIP (https://nvbugs/5740377)
 unittest/_torch/modules/test_fused_moe.py::test_fused_moe_fp8_blockwise_cute_dsl_multi_gpu[MoEWeightLoadingMode.FUSED_GATE_UP_PROJ-DefaultMoeRoutingMethod-1] SKIP (https://nvbugs/5775256)
-unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py::test_ep_shard[3-2] SKIP (https://nvbugs/5777041)
-unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py::test_ep_shard[8-2] SKIP (https://nvbugs/5777041)
 disaggregated/test_disaggregated.py::test_disaggregated_benchmark_on_diff_backends[llama-3.1-8b-instruct-hf-fp8] SKIP (https://nvbugs/5769890)
 disaggregated/test_disaggregated.py::test_disaggregated_benchmark_on_diff_backends[DeepSeek-V3-Lite-bf16] SKIP (https://nvbugs/5769890)
 disaggregated/test_disaggregated.py::test_disaggregated_benchmark_on_diff_backends[llama-v3-8b-hf] SKIP (https://nvbugs/5769890,https://nvbugs/5748683)
diff --git a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py
index 8c034799ad7..a1e642db318 100644
--- a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py
+++ b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py
@@ -31,10 +31,6 @@ def _run_ep_shard_job(num_experts: int, rank: int, world_size: int) -> None:
     ).to(device=device, dtype=torch.bfloat16)
     x = model.get_input(device=device, dtype=torch.bfloat16)
 
-    if world_size > num_experts:
-        print(f"world_size {world_size} > num_experts {num_experts}, skipping test")
-        return
-
     def _get_expected_num_params(rank: int, world_size: int, num_p_og: int) -> int:
         if world_size <= 1:
             return num_p_og
@@ -141,9 +137,11 @@ def _run_pattern_detection_job(num_experts: int, rank: int, world_size: int) ->
     run_sharding_pattern_detection_test(detected_transformations, expected_transformations)
 
 
-@pytest.mark.parametrize("device_count", get_device_counts())
+@pytest.mark.parametrize("device_count", get_device_counts([2, 8]))
 @pytest.mark.parametrize("num_experts", [3, 8])
 def test_ep_shard(device_count: int, num_experts: int):
+    if device_count > num_experts:
+        pytest.skip(f"world_size {device_count} > num_experts {num_experts}")
     dist_common.spawn_multiprocess_job(
         job=partial(_run_ep_shard_job, num_experts),
         size=device_count,
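
Note on the parametrization change: the call site get_device_counts([2, 8]) suggests the helper accepts an optional list of candidate world sizes and presumably filters them against the GPUs actually visible on the machine. A minimal sketch of that assumed behavior (the real helper lives in the repo's test utilities and may differ):

    import torch

    def get_device_counts(candidates=None):
        # Sketch of the assumed helper: keep only the candidate device
        # counts that the visible GPUs can satisfy. The default candidate
        # list here is an assumption, not taken from the repo.
        candidates = candidates or [2]
        available = torch.cuda.device_count()
        return [c for c in candidates if c <= available]

Under this reading, the world_size > num_experts guard moves out of the spawned worker, where a bare return passed silently, and into the test body, where pytest.skip reports the infeasible device_count/num_experts combinations explicitly; the two waives.txt entries for nvbugs/5777041 are dropped accordingly.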