1 file changed (+8 −1): tests/unittest/_torch/modules

@@ -24,6 +24,7 @@
     skip_pre_hopper)
 
 from tensorrt_llm._torch.autotuner import AutoTuner, autotune
+from tensorrt_llm._torch.distributed import MPIDist, TorchDist
 from tensorrt_llm._torch.model_config import ModelConfig
 from tensorrt_llm._torch.modules.fused_moe.fused_moe_cute_dsl import \
     CuteDslFusedMoE
@@ -44,7 +45,7 @@
 from tensorrt_llm._torch.modules.fused_moe.fused_moe_triton import \
     IS_TRITON_KERNELS_AVAILABLE
 from tensorrt_llm._torch.modules.gated_mlp import GatedMLP
-from tensorrt_llm._utils import get_sm_version, mpi_rank
+from tensorrt_llm._utils import get_sm_version, mpi_disabled, mpi_rank
 from tensorrt_llm.mapping import Mapping
 from tensorrt_llm.models.modeling_utils import QuantAlgo, QuantConfig
 
@@ -104,6 +105,12 @@ def test_fused_moe(moe_backend,
 
     mapping = mapping or Mapping()
     mapping.rank = mpi_rank()
+    if mpi_disabled():
+        dist = TorchDist(mapping=mapping)
+    else:
+        dist = MPIDist(mapping=mapping)
+
+    AutoTuner.get().setup_distributed_state(mapping, dist)
 
     torch.cuda.set_device(mapping.rank)
 
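In short, the test now picks its distributed backend at runtime (TorchDist when MPI is disabled, MPIDist otherwise) and registers it with the AutoTuner singleton before any tuning happens. Below is a minimal sketch of that pattern, assuming tensorrt_llm is importable; the wrapper name setup_tuner_dist is hypothetical, and only the imports and calls shown in the diff are taken from the source.

# Minimal sketch of the backend-selection pattern from the diff above.
# Assumes tensorrt_llm is installed; setup_tuner_dist is a hypothetical
# helper introduced here for illustration only.
from tensorrt_llm._torch.autotuner import AutoTuner
from tensorrt_llm._torch.distributed import MPIDist, TorchDist
from tensorrt_llm._utils import mpi_disabled, mpi_rank
from tensorrt_llm.mapping import Mapping


def setup_tuner_dist(mapping=None):
    """Pick a distributed backend and hand it to the AutoTuner singleton."""
    mapping = mapping or Mapping()
    mapping.rank = mpi_rank()

    # TorchDist covers runs where MPI is unavailable or explicitly
    # disabled; MPIDist is the regular MPI path.
    if mpi_disabled():
        dist = TorchDist(mapping=mapping)
    else:
        dist = MPIDist(mapping=mapping)

    # Register the mapping/backend pair with the AutoTuner so tuning
    # state can be coordinated across ranks.
    AutoTuner.get().setup_distributed_state(mapping, dist)
    return mapping, dist

Doing this once at test setup, before torch.cuda.set_device, presumably keeps the autotuner's kernel choices consistent across ranks rather than letting each process tune independently.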