
Commit b3d794f

Fix CI error.
Signed-off-by: Yuxian Qiu <[email protected]>
1 parent 85c7acc commit b3d794f

File tree

2 files changed: +4 additions, -4 deletions


examples/layer_wise_benchmarks/run.py

Lines changed: 2 additions & 4 deletions

@@ -10,11 +10,10 @@
 import yaml

 from tensorrt_llm._torch.autotuner import AutoTuner, autotune
-from tensorrt_llm._torch.distributed import MPIDist, TorchDist
 from tensorrt_llm._torch.modules.fused_moe.fused_moe_cutlass import CutlassFusedMoE
 from tensorrt_llm._torch.modules.fused_moe.interface import AlltoallMethodType
 from tensorrt_llm._torch.modules.multi_stream_utils import with_multi_stream
-from tensorrt_llm._utils import local_mpi_rank, mpi_disabled, mpi_rank, mpi_world_size
+from tensorrt_llm._utils import local_mpi_rank, mpi_rank, mpi_world_size
 from tensorrt_llm.logger import logger
 from tensorrt_llm.tools.layer_wise_benchmarks import BalanceMethod, get_runner_cls, mark_ranges

@@ -174,8 +173,7 @@ def comma_separated_floats(s):
     )
     if args.enable_autotuner:
         cache_path = os.getenv("TLLM_AUTOTUNER_CACHE_PATH") or None
-        dist = TorchDist(mapping=mapping) if mpi_disabled() else MPIDist(mapping=mapping)
-        AutoTuner.get().setup_distributed_state(mapping, dist)
+        AutoTuner.get().setup_distributed_state(mapping)
         with autotune(cache_path=cache_path):
             run_pack()
     else:
tests/unittest/_torch/misc/test_autotuner.py

Lines changed: 2 additions & 0 deletions

@@ -17,6 +17,7 @@
                                           FakeTensor, OptimizationProfile,
                                           StaticDim, TunableRunner,
                                           TuningConfig, autotune)
+from tensorrt_llm._torch.distributed import Distributed
 from tensorrt_llm._torch.utils import (get_power_of_2_num_tokens_buckets,
                                        next_positive_power_of_2)
 from tensorrt_llm.bindings.internal.runtime import delay_kernel

@@ -718,6 +719,7 @@ def _distributed_worker_function(world_size, strategy):
                       rank=rank,
                       tp_size=world_size,
                       pp_size=1)
+    dist = Distributed.get(mapping)

     tuner = AutoTuner.get()
     tuner.clear_cache()
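The test-side change is the mirror image: now that setup_distributed_state no longer receives a dist argument, the worker obtains its communicator from the mapping through Distributed.get. Below is a hedged sketch of that worker setup; how the real test derives the rank and what it does with dist afterwards is not shown in this hunk, so those parts are assumptions.

```python
from tensorrt_llm._torch.autotuner import AutoTuner
from tensorrt_llm._torch.distributed import Distributed
from tensorrt_llm._utils import mpi_rank  # assumed source of the worker's rank
from tensorrt_llm.mapping import Mapping  # assumed import path for Mapping


def _distributed_worker_function(world_size, strategy):
    rank = mpi_rank()  # the hunk does not show where rank comes from
    mapping = Mapping(world_size=world_size,
                      rank=rank,
                      tp_size=world_size,
                      pp_size=1)

    # New in this commit: the worker resolves a Distributed instance from the
    # mapping itself rather than relying on the autotuner to build one.
    dist = Distributed.get(mapping)

    tuner = AutoTuner.get()
    tuner.clear_cache()
    # ... the rest of the test (tuning with the given strategy and using dist)
    # is unchanged and omitted here.
```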
