examples/layer_wise_benchmarks/run.py (2 additions, 4 deletions)

@@ -10,9 +10,8 @@
 import yaml

 from tensorrt_llm._torch.autotuner import AutoTuner, autotune
-from tensorrt_llm._torch.distributed import MPIDist, TorchDist
 from tensorrt_llm._torch.modules.multi_stream_utils import with_multi_stream
-from tensorrt_llm._utils import local_mpi_rank, mpi_disabled, mpi_rank, mpi_world_size
+from tensorrt_llm._utils import local_mpi_rank, mpi_rank, mpi_world_size
 from tensorrt_llm.logger import logger
 from tensorrt_llm.tools.layer_wise_benchmarks import BalanceMethod, Runner, mark_ranges

@@ -192,8 +191,7 @@ def comma_separated_floats(s):
 )
 if args.enable_autotuner:
     cache_path = os.getenv("TLLM_AUTOTUNER_CACHE_PATH") or None
-    dist = TorchDist(mapping=mapping) if mpi_disabled() else MPIDist(mapping=mapping)
-    AutoTuner.get().setup_distributed_state(mapping, dist)
+    AutoTuner.get().setup_distributed_state(mapping)
     with autotune(cache_path=cache_path):
         run_pack()
 else:
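With this change the benchmark no longer chooses between MPIDist and TorchDist itself; setup_distributed_state builds the communicator from the Mapping by default, while the Ellipsis-default signature added in autotuner.py (further down in this diff) still lets a caller inject one, or pass None to opt out. A minimal sketch of the two call patterns, assuming a single-rank Mapping is acceptable here; the Mapping arguments are illustrative:

    from tensorrt_llm._torch.autotuner import AutoTuner, autotune
    from tensorrt_llm.mapping import Mapping

    mapping = Mapping(rank=0, world_size=1, tp_size=1)  # illustrative single-rank setup

    # Default path: the AutoTuner builds the communicator from the mapping itself.
    AutoTuner.get().setup_distributed_state(mapping)

    # Test path: pass dist explicitly (here None) to control distributed tuning directly.
    AutoTuner.get().setup_distributed_state(mapping, dist=None)

    with autotune(cache_path=None):
        ...  # run the kernels to be tuned here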
tensorrt_llm/_ipc_utils.py (3 additions, 12 deletions)

@@ -17,15 +17,12 @@
 import sys
 from typing import List, Tuple

-from tensorrt_llm._utils import mpi_disabled
-
 try:
     from cuda.bindings import driver as cuda
     from cuda.bindings import runtime as cudart
 except ImportError:
     from cuda import cuda, cudart

-from ._utils import mpi_comm
 from .logger import logger
 from .mapping import Mapping

@@ -107,15 +104,9 @@ def align_size(size, alignment):
         size += alignment - (size % alignment)
         return size

-    if mpi_disabled():
-        from tensorrt_llm._utils import torch_comm
-
-        allgather = torch_comm().tp_allgather
-    else:
-        comm = mpi_comm().Split(
-            mapping.pp_rank * mapping.cp_size + mapping.cp_rank, mapping.tp_rank
-        )
-        allgather = comm.allgather
+    from tensorrt_llm._torch.distributed.communicator import Distributed
+
+    dist = Distributed.get(mapping)

     # see allocateIpcMemory in cpp/tensorrt_llm/runtime/ipcUtils.cpp for alignment reason
     # 1 << 21 is 2MB
@@ -126,7 +117,7 @@
     _raise_if_error(cudart.cudaMemset(local_ptr, 0, aligned_size)[0])
     error, local_handle = cudart.cudaIpcGetMemHandle(local_ptr)
     _raise_if_error(error)
-    handles_reserved = allgather(local_handle.reserved)
+    handles_reserved = dist.tp_allgather(local_handle.reserved)

     handles = []
     for reserved in handles_reserved:
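The removed if/else above now lives behind Distributed.get(mapping). The body of that factory is not part of this diff; a hypothetical sketch of the selection it stands in for, mirroring the inline logic removed from run.py (the helper name is made up; only MPIDist, TorchDist, and mpi_disabled come from the old code):

    from tensorrt_llm._torch.distributed import MPIDist, TorchDist
    from tensorrt_llm._utils import mpi_disabled
    from tensorrt_llm.mapping import Mapping


    def pick_communicator(mapping: Mapping):
        # Same decision the call sites used to make inline: a torch-based
        # communicator when MPI is disabled, otherwise the MPI communicator.
        return TorchDist(mapping=mapping) if mpi_disabled() else MPIDist(mapping=mapping)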
tensorrt_llm/_torch/auto_deploy/distributed/common.py (1 addition, 5 deletions)

@@ -9,7 +9,7 @@
 import torch.distributed as dist
 import torch.multiprocessing as mp

-from tensorrt_llm._utils import get_free_port as _get_free_port
+from tensorrt_llm._utils import get_free_port

 from ..utils.logger import ad_logger

@@ -69,10 +69,6 @@ def all_gather_object(object_list, object, group=None):
     return dist.all_gather_object(object_list, object, group=group)


-def get_free_port():
-    return _get_free_port()
-
-
 def get_world_size() -> int:
     return dist.get_world_size()
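The local wrapper around get_free_port is gone; callers now import it straight from tensorrt_llm._utils. Its implementation is not shown in this diff; the conventional pattern behind such helpers is to bind an ephemeral port and let the OS pick one, as in this generic sketch (not TensorRT-LLM's code):

    import socket


    def find_free_port() -> int:
        # Bind to port 0 so the OS assigns an unused port, then report it.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(("", 0))
            return s.getsockname()[1]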
tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py (12 additions, 12 deletions)

@@ -47,10 +47,10 @@
 )
 from tensorrt_llm.llmapi.tokenizer import TokenizerBase

-from ...._utils import mpi_rank, mpi_world_size
+from ...._utils import get_free_port, mpi_rank, mpi_world_size
 from ....bindings.internal.batch_manager import CacheType
 from ....mapping import Mapping
-from ...distributed import MPIDist
+from ...distributed import Distributed
 from ...pyexecutor.model_engine import ModelEngine, PyTorchModelEngine
 from ...pyexecutor.py_executor import PyExecutor
 from ...pyexecutor.resource_manager import (
@@ -68,7 +68,7 @@
     SimpleScheduler,
 )
 from ..custom_ops.attention_interface import SequenceInfo
-from ..distributed import common as dist
+from ..distributed.common import initialize_or_skip
 from ..llm_args import LlmArgs
 from ..transform.optimizer import InferenceOptimizer
 from ..utils.logger import ad_logger
@@ -880,15 +880,15 @@ def share_lm_head_weights_with_draft(


 def create_draft_model_engine_maybe(
-    ad_config: LlmArgs, target_engine: ADEngine, dist_mapping: Mapping, mpi_dist: MPIDist
+    ad_config: LlmArgs, target_engine: ADEngine, dist_mapping: Mapping, dist: Distributed
 ) -> Optional[PyTorchModelEngine]:
     """Create a draft model engine for speculative decoding.

     Args:
         ad_config: The AutoDeploy LLM configuration
         engine: The target model engine (ADEngine)
         dist_mapping: The distributed mapping configuration
-        mpi_dist: The MPI distribution object
+        dist: The distribution object

     Returns:
         PyTorchModelEngine configured as a draft model, or None if not needed
@@ -925,7 +925,7 @@ def create_draft_model_engine_maybe(
         llm_args=draft_llm_args,
         mapping=dist_mapping,
         attn_runtime_features=attn_runtime_features,
-        dist=mpi_dist,
+        dist=dist,
         spec_config=draft_spec_config,
         is_draft_model=True,
         drafting_loop_wrapper=drafting_loop_wrapper,
@@ -1004,14 +1004,14 @@ def create_autodeploy_executor(ad_config: LlmArgs, tokenizer: Optional[Tokenizer
     world_size = mpi_world_size()
     rank = mpi_rank()
     dist_mapping = Mapping(rank=rank, world_size=world_size, tp_size=world_size)
-    mpi_dist = MPIDist(dist_mapping)
+    dist = Distributed.get(dist_mapping)
     ad_logger.set_rank(rank)
     torch.cuda.set_device(rank)
-    port = mpi_dist.broadcast(dist.get_free_port())  # use MPI broadcast to pick a free port
-    dist.initialize_or_skip(rank, world_size, port)
+    port = dist.broadcast(get_free_port())  # use MPI broadcast to pick a free port
+    initialize_or_skip(rank, world_size, port)

     # Setup AutoTuner with distributed state for allreduce autotuning
-    AutoTuner.get().setup_distributed_state(dist_mapping, mpi_dist)
+    AutoTuner.get().setup_distributed_state(dist_mapping)

     # some config
     assert ad_config.max_beam_width <= 1, "_autodeploy + beam_search is not supported"
@@ -1044,7 +1044,7 @@ def create_autodeploy_executor(ad_config: LlmArgs, tokenizer: Optional[Tokenizer
     )

     draft_model_engine = create_draft_model_engine_maybe(
-        ad_config=ad_config, target_engine=engine, dist_mapping=dist_mapping, mpi_dist=mpi_dist
+        ad_config=ad_config, target_engine=engine, dist_mapping=dist_mapping, dist=dist
     )

     spec_resource_manager = (
@@ -1171,7 +1171,7 @@ def create_autodeploy_executor(ad_config: LlmArgs, tokenizer: Optional[Tokenizer
         scheduler,
         model_engine=engine,
         sampler=sampler,
-        dist=mpi_dist,
+        dist=dist,
         max_num_sequences=max_num_sequences,
         disable_overlap_scheduler=ad_config.disable_overlap_scheduler,
         max_input_len=ad_config.max_input_len,
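The executor setup now reads as: build the communicator from the mapping, have it broadcast a free port, then call initialize_or_skip. The body of initialize_or_skip is not in this diff; a generic sketch of the handshake it performs, using plain torch.distributed with assumed defaults (backend, master address) for illustration:

    import os

    import torch.distributed as dist


    def init_torch_distributed_or_skip(rank: int, world_size: int, port: int) -> None:
        # "or skip": reuse an existing process group if one is already set up.
        if dist.is_initialized():
            return
        os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
        os.environ["MASTER_PORT"] = str(port)  # the port every rank received via broadcast
        dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)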
tensorrt_llm/_torch/autotuner.py (7 additions, 5 deletions)

@@ -1072,9 +1072,7 @@ def pure_profile(stream: torch.cuda.Stream, repeat: int):
         stream.synchronize()
         if tuning_config.distributed_tuning_strategy == DistributedTuningStrategy.MERGE:
             # Currently only AllReduce will use this strategy, and only MPI parallel will enable tuning.
-            # TODO: Unified tp barrier for both MPIDist and TorchDist.
-            if hasattr(self._dist, "tp_comm"):
-                self._dist.tp_comm.barrier()
+            self._dist.tp_barrier()

         # Delay the profiled kernel launch to eliminate affects of host time overhead in profiling.
         if use_cuda_graph:
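The removed hasattr check only synchronized ranks when the communicator happened to be MPI-backed; the new tp_barrier call assumes every backend exposes a tensor-parallel barrier. The backend methods are not part of this hunk; a hypothetical sketch of what they could look like on each side (class names and constructor arguments are illustrative):

    import torch.distributed as dist


    class MpiBackendSketch:
        def __init__(self, tp_comm):
            self.tp_comm = tp_comm  # mpi4py communicator split over the TP group

        def tp_barrier(self) -> None:
            self.tp_comm.barrier()  # what the autotuner previously called directly


    class TorchBackendSketch:
        def __init__(self, tp_group):
            self.tp_group = tp_group  # torch.distributed ProcessGroup covering the TP ranks

        def tp_barrier(self) -> None:
            dist.barrier(group=self.tp_group)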
@@ -1495,10 +1493,14 @@ def _cudaGetErrorEnum(self, error) -> str:
         else:
             raise RuntimeError("Unknown error type: {}".format(error))

-    def setup_distributed_state(self, mapping: Mapping, dist: Distributed):
+    def setup_distributed_state(self,
+                                mapping: Mapping,
+                                dist: Optional[Distributed] = ...):
         """Setup distributed communication state for autotuning."""
         self.mapping = mapping
-        self._dist = dist
+        # Create dist only when dist is not provided.
+        # Use the provided dist even if it is None. This is useful for testing.
+        self._dist = Distributed.get(mapping) if dist is ... else dist
         self._debug_logger(
             f"[AutoTuner] Whether using distributed tuning: {self._is_distributed()}"
         )
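Using Ellipsis as the default lets setup_distributed_state tell "argument omitted" (build a communicator from the mapping) apart from "explicitly None" (run without one, e.g. in unit tests), which a plain None default could not express. A standalone illustration of the sentinel pattern, unrelated to TensorRT-LLM's types:

    from typing import Optional


    class TunerSketch:
        def setup(self, comm: Optional[object] = ...) -> None:
            if comm is ...:
                self.comm = object()  # argument omitted: create a default communicator
            else:
                self.comm = comm      # caller's explicit choice, including None


    t = TunerSketch()
    t.setup()            # default communicator is created
    t.setup(comm=None)   # distributed behavior explicitly disabled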