Commit b85c447

[https://nvbugs/5784543][fix] Setup dist before using autotuner. (#10491)
Signed-off-by: Yuxian Qiu <[email protected]>
1 parent: 09d9878

File tree: 1 file changed (+8 −1 lines)


tests/unittest/_torch/modules/test_fused_moe.py

Lines changed: 8 additions & 1 deletion
```diff
@@ -24,6 +24,7 @@
                         skip_pre_hopper)
 
 from tensorrt_llm._torch.autotuner import AutoTuner, autotune
+from tensorrt_llm._torch.distributed import MPIDist, TorchDist
 from tensorrt_llm._torch.model_config import ModelConfig
 from tensorrt_llm._torch.modules.fused_moe.fused_moe_cute_dsl import \
     CuteDslFusedMoE
@@ -44,7 +45,7 @@
 from tensorrt_llm._torch.modules.fused_moe.fused_moe_triton import \
     IS_TRITON_KERNELS_AVAILABLE
 from tensorrt_llm._torch.modules.gated_mlp import GatedMLP
-from tensorrt_llm._utils import get_sm_version, mpi_rank
+from tensorrt_llm._utils import get_sm_version, mpi_disabled, mpi_rank
 from tensorrt_llm.mapping import Mapping
 from tensorrt_llm.models.modeling_utils import QuantAlgo, QuantConfig
 
@@ -104,6 +105,12 @@ def test_fused_moe(moe_backend,
 
     mapping = mapping or Mapping()
     mapping.rank = mpi_rank()
+    if mpi_disabled():
+        dist = TorchDist(mapping=mapping)
+    else:
+        dist = MPIDist(mapping=mapping)
+
+    AutoTuner.get().setup_distributed_state(mapping, dist)
 
     torch.cuda.set_device(mapping.rank)
 
```
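For reference, here is the pattern this fix applies, pulled out of the test body: pick the distributed backend that matches how the process was launched, then register it with the `AutoTuner` singleton before any tuned kernel runs. This is a minimal sketch assuming only the imports shown in the diff; treating `autotune` as a context manager is inferred from the import and is not confirmed by this commit.

```python
from tensorrt_llm._torch.autotuner import AutoTuner, autotune
from tensorrt_llm._torch.distributed import MPIDist, TorchDist
from tensorrt_llm._utils import mpi_disabled, mpi_rank
from tensorrt_llm.mapping import Mapping

mapping = Mapping()
mapping.rank = mpi_rank()

# Select the backend matching the launcher: TorchDist when MPI is
# disabled (e.g. a torchrun launch), MPIDist otherwise.
dist = TorchDist(mapping=mapping) if mpi_disabled() else MPIDist(mapping=mapping)

# Register the distributed state with the process-wide AutoTuner
# singleton *before* tuning, matching the fix's ordering
# ("setup dist before using autotuner").
AutoTuner.get().setup_distributed_state(mapping, dist)

with autotune():  # assumption: autotune enables tuning as a context manager
    pass  # run the fused-MoE forward pass that triggers autotuning here
```

The ordering is the whole point of the fix: the tuner is a singleton, so the distributed state must be in place before the first tuned op executes, presumably so that ranks can coordinate during tuning rather than tuning independently.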