diff --git a/examples/layer_wise_benchmarks/run.py b/examples/layer_wise_benchmarks/run.py
index d84525c1d33..35f4b89c2f4 100644
--- a/examples/layer_wise_benchmarks/run.py
+++ b/examples/layer_wise_benchmarks/run.py
@@ -10,11 +10,10 @@
 import yaml
 
 from tensorrt_llm._torch.autotuner import AutoTuner, autotune
-from tensorrt_llm._torch.distributed import MPIDist, TorchDist
 from tensorrt_llm._torch.modules.fused_moe.fused_moe_cutlass import CutlassFusedMoE
 from tensorrt_llm._torch.modules.fused_moe.interface import AlltoallMethodType
 from tensorrt_llm._torch.modules.multi_stream_utils import with_multi_stream
-from tensorrt_llm._utils import local_mpi_rank, mpi_disabled, mpi_rank, mpi_world_size
+from tensorrt_llm._utils import local_mpi_rank, mpi_rank, mpi_world_size
 from tensorrt_llm.logger import logger
 from tensorrt_llm.tools.layer_wise_benchmarks import BalanceMethod, get_runner_cls, mark_ranges
 
@@ -174,8 +173,7 @@ def comma_separated_floats(s):
     )
     if args.enable_autotuner:
         cache_path = os.getenv("TLLM_AUTOTUNER_CACHE_PATH") or None
-        dist = TorchDist(mapping=mapping) if mpi_disabled() else MPIDist(mapping=mapping)
-        AutoTuner.get().setup_distributed_state(mapping, dist)
+        AutoTuner.get().setup_distributed_state(mapping)
         with autotune(cache_path=cache_path):
             run_pack()
     else:
diff --git a/tensorrt_llm/_ipc_utils.py b/tensorrt_llm/_ipc_utils.py
index 1d8e911633b..4d5ecefc9a9 100644
--- a/tensorrt_llm/_ipc_utils.py
+++ b/tensorrt_llm/_ipc_utils.py
@@ -17,15 +17,12 @@
 import sys
 from typing import List, Tuple
 
-from tensorrt_llm._utils import mpi_disabled
-
 try:
     from cuda.bindings import driver as cuda
     from cuda.bindings import runtime as cudart
 except ImportError:
     from cuda import cuda, cudart
 
-from ._utils import mpi_comm
 from .logger import logger
 from .mapping import Mapping
 
@@ -107,15 +104,9 @@ def align_size(size, alignment):
         size += alignment - (size % alignment)
         return size
 
-    if mpi_disabled():
-        from tensorrt_llm._utils import torch_comm
+    from tensorrt_llm._torch.distributed.communicator import Distributed
 
-        allgather = torch_comm().tp_allgather
-    else:
-        comm = mpi_comm().Split(
-            mapping.pp_rank * mapping.cp_size + mapping.cp_rank, mapping.tp_rank
-        )
-        allgather = comm.allgather
+    dist = Distributed.get(mapping)
 
     # see allocateIpcMemory in cpp/tensorrt_llm/runtime/ipcUtils.cpp for alignment reason
     # 1 << 21 is 2MB
@@ -126,7 +117,7 @@ def align_size(size, alignment):
     _raise_if_error(cudart.cudaMemset(local_ptr, 0, aligned_size)[0])
     error, local_handle = cudart.cudaIpcGetMemHandle(local_ptr)
     _raise_if_error(error)
-    handles_reserved = allgather(local_handle.reserved)
+    handles_reserved = dist.tp_allgather(local_handle.reserved)
 
     handles = []
     for reserved in handles_reserved:
diff --git a/tensorrt_llm/_torch/auto_deploy/distributed/common.py b/tensorrt_llm/_torch/auto_deploy/distributed/common.py
index 27d0ffded49..9f1dc85f106 100644
--- a/tensorrt_llm/_torch/auto_deploy/distributed/common.py
+++ b/tensorrt_llm/_torch/auto_deploy/distributed/common.py
@@ -9,7 +9,7 @@
 import torch.distributed as dist
 import torch.multiprocessing as mp
 
-from tensorrt_llm._utils import get_free_port as _get_free_port
+from tensorrt_llm._utils import get_free_port
 
 from ..utils.logger import ad_logger
 
@@ -69,10 +69,6 @@ def all_gather_object(object_list, object, group=None):
     return dist.all_gather_object(object_list, object, group=group)
 
 
-def get_free_port():
-    return _get_free_port()
-
-
 def get_world_size() -> int:
     return dist.get_world_size()
 
diff --git a/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py b/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
index 9f6b885d3b4..48066cb2565 100644
--- a/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
+++ b/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
@@ -47,10 +47,10 @@
 )
 from tensorrt_llm.llmapi.tokenizer import TokenizerBase
 
-from ...._utils import mpi_rank, mpi_world_size
+from ...._utils import get_free_port, mpi_rank, mpi_world_size
 from ....bindings.internal.batch_manager import CacheType
 from ....mapping import Mapping
-from ...distributed import MPIDist
+from ...distributed import Distributed
 from ...pyexecutor.model_engine import ModelEngine, PyTorchModelEngine
 from ...pyexecutor.py_executor import PyExecutor
 from ...pyexecutor.resource_manager import (
@@ -68,7 +68,7 @@
     SimpleScheduler,
 )
 from ..custom_ops.attention_interface import SequenceInfo
-from ..distributed import common as dist
+from ..distributed.common import initialize_or_skip
 from ..llm_args import LlmArgs
 from ..transform.optimizer import InferenceOptimizer
 from ..utils.logger import ad_logger
@@ -880,7 +880,7 @@ def share_lm_head_weights_with_draft(
 
 
 def create_draft_model_engine_maybe(
-    ad_config: LlmArgs, target_engine: ADEngine, dist_mapping: Mapping, mpi_dist: MPIDist
+    ad_config: LlmArgs, target_engine: ADEngine, dist_mapping: Mapping, dist: Distributed
 ) -> Optional[PyTorchModelEngine]:
     """Create a draft model engine for speculative decoding.
 
@@ -888,7 +888,7 @@ def create_draft_model_engine_maybe(
         ad_config: The AutoDeploy LLM configuration
         engine: The target model engine (ADEngine)
         dist_mapping: The distributed mapping configuration
-        mpi_dist: The MPI distribution object
+        dist: The distribution object
 
     Returns:
         PyTorchModelEngine configured as a draft model, or None if not needed
@@ -925,7 +925,7 @@ def create_draft_model_engine_maybe(
         llm_args=draft_llm_args,
         mapping=dist_mapping,
         attn_runtime_features=attn_runtime_features,
-        dist=mpi_dist,
+        dist=dist,
         spec_config=draft_spec_config,
         is_draft_model=True,
         drafting_loop_wrapper=drafting_loop_wrapper,
@@ -1004,14 +1004,14 @@ def create_autodeploy_executor(ad_config: LlmArgs, tokenizer: Optional[Tokenizer
     world_size = mpi_world_size()
     rank = mpi_rank()
     dist_mapping = Mapping(rank=rank, world_size=world_size, tp_size=world_size)
-    mpi_dist = MPIDist(dist_mapping)
+    dist = Distributed.get(dist_mapping)
     ad_logger.set_rank(rank)
     torch.cuda.set_device(rank)
-    port = mpi_dist.broadcast(dist.get_free_port())  # use MPI broadcast to pick a free port
-    dist.initialize_or_skip(rank, world_size, port)
+    port = dist.broadcast(get_free_port())  # use MPI broadcast to pick a free port
+    initialize_or_skip(rank, world_size, port)
 
     # Setup AutoTuner with distributed state for allreduce autotuning
-    AutoTuner.get().setup_distributed_state(dist_mapping, mpi_dist)
+    AutoTuner.get().setup_distributed_state(dist_mapping)
 
     # some config
     assert ad_config.max_beam_width <= 1, "_autodeploy + beam_search is not supported"
@@ -1044,7 +1044,7 @@ def create_autodeploy_executor(ad_config: LlmArgs, tokenizer: Optional[Tokenizer
     )
 
     draft_model_engine = create_draft_model_engine_maybe(
-        ad_config=ad_config, target_engine=engine, dist_mapping=dist_mapping, mpi_dist=mpi_dist
+        ad_config=ad_config, target_engine=engine, dist_mapping=dist_mapping, dist=dist
     )
 
     spec_resource_manager = (
@@ -1171,7 +1171,7 @@ def create_autodeploy_executor(ad_config: LlmArgs, tokenizer: Optional[Tokenizer
         scheduler,
         model_engine=engine,
         sampler=sampler,
-        dist=mpi_dist,
+        dist=dist,
         max_num_sequences=max_num_sequences,
         disable_overlap_scheduler=ad_config.disable_overlap_scheduler,
         max_input_len=ad_config.max_input_len,
diff --git a/tensorrt_llm/_torch/autotuner.py b/tensorrt_llm/_torch/autotuner.py
index 1f18270687c..c78b1a0319a 100644
--- a/tensorrt_llm/_torch/autotuner.py
+++ b/tensorrt_llm/_torch/autotuner.py
@@ -1072,9 +1072,7 @@ def pure_profile(stream: torch.cuda.Stream, repeat: int):
             stream.synchronize()
             if tuning_config.distributed_tuning_strategy == DistributedTuningStrategy.MERGE:
                 # Currently only AllReduce will use this strategy, and only MPI parallel will enable tuning.
-                # TODO: Unified tp barrier for both MPIDist and TorchDist.
-                if hasattr(self._dist, "tp_comm"):
-                    self._dist.tp_comm.barrier()
+                self._dist.tp_barrier()
 
             # Delay the profiled kernel launch to eliminate affects of host time overhead in profiling.
             if use_cuda_graph:
@@ -1495,10 +1493,14 @@ def _cudaGetErrorEnum(self, error) -> str:
         else:
             raise RuntimeError("Unknown error type: {}".format(error))
 
-    def setup_distributed_state(self, mapping: Mapping, dist: Distributed):
+    def setup_distributed_state(self,
+                                mapping: Mapping,
+                                dist: Optional[Distributed] = ...):
         """Setup distributed communication state for autotuning."""
         self.mapping = mapping
-        self._dist = dist
+        # Create dist only when dist is not provided.
+        # Use the provided dist even if it is None. This is useful for testing.
+        self._dist = Distributed.get(mapping) if dist is ... else dist
         self._debug_logger(
             f"[AutoTuner] Whether using distributed tuning: {self._is_distributed()}"
         )
diff --git a/tensorrt_llm/_torch/distributed/communicator.py b/tensorrt_llm/_torch/distributed/communicator.py
index 09bbc234ee2..0abad1ea496 100644
--- a/tensorrt_llm/_torch/distributed/communicator.py
+++ b/tensorrt_llm/_torch/distributed/communicator.py
@@ -1,8 +1,8 @@
-import copy
 import math
 import pickle  # nosec B403
 from abc import ABC, abstractmethod
-from functools import wraps
+from enum import IntEnum
+from functools import lru_cache, wraps
 from typing import List, Optional
 
 import numpy as np
@@ -31,11 +31,59 @@
     from tensorrt_llm import ray_stub as ray
 
 
+class ReduceOp(IntEnum):
+    SUM = 0
+    PRODUCT = 1
+    MIN = 2
+    MAX = 3
+    BAND = 4
+    BOR = 5
+    BXOR = 6
+
+
+_reduce_op_to_torch_dict = {
+    ReduceOp.SUM: torch.distributed.ReduceOp.SUM,
+    ReduceOp.PRODUCT: torch.distributed.ReduceOp.PRODUCT,
+    ReduceOp.MIN: torch.distributed.ReduceOp.MIN,
+    ReduceOp.MAX: torch.distributed.ReduceOp.MAX,
+    ReduceOp.BAND: torch.distributed.ReduceOp.BAND,
+    ReduceOp.BOR: torch.distributed.ReduceOp.BOR,
+    ReduceOp.BXOR: torch.distributed.ReduceOp.BXOR,
+}
+
+
+def reduce_op_to_torch(op: ReduceOp) -> torch.distributed.ReduceOp:
+    return _reduce_op_to_torch_dict[op]
+
+
+_reduce_op_to_mpi_dict = {
+    ReduceOp.SUM: MPI.SUM,
+    ReduceOp.PRODUCT: MPI.PROD,
+    ReduceOp.MIN: MPI.MIN,
+    ReduceOp.MAX: MPI.MAX,
+    ReduceOp.BAND: MPI.BAND,
+    ReduceOp.BOR: MPI.BOR,
+    ReduceOp.BXOR: MPI.BXOR,
+}
+
+
+def reduce_op_to_mpi(op: ReduceOp) -> MPI.Op:
+    return _reduce_op_to_mpi_dict[op]
+
+
 class Distributed(ABC):
 
     def __init__(self, mapping: Mapping):
        self.mapping = mapping
 
+    @staticmethod
+    @lru_cache(maxsize=None)
+    def get(mapping: Mapping) -> "Distributed":
+        if mpi_disabled():
+            return TorchDist(mapping)
+        else:
+            return MPIDist(mapping)
+
     @property
     def rank(self):
         return self.mapping.rank
@@ -108,6 +156,14 @@ def has_cp_helix(self):
     def cp_config(self):
         return self.mapping.cp_config
 
+    @abstractmethod
+    def barrier(self):
+        pass
+
+    @abstractmethod
+    def tp_barrier(self):
+        pass
+
     @abstractmethod
     def broadcast(self, obj, root=0):
         pass
@@ -116,6 +172,10 @@ def broadcast(self, obj, root=0):
     def allgather(self, obj, root=0):
         pass
 
+    @abstractmethod
+    def allreduce(self, obj, op: ReduceOp = ReduceOp.SUM):
+        pass
+
     @abstractmethod
     def tp_broadcast(self, obj, root=0, **kwargs):
         pass
@@ -362,24 +422,9 @@ class MPIDist(Distributed):
 
     def __init__(self, mapping: Mapping):
         super().__init__(mapping)
-        self.create_cp_comm()
-        # Repurpose CP ranks to TP for Helix so that the right comms are created.
-        mapping_with_cp = None
-        if self.mapping.has_cp_helix():
-            logger.info(
-                f"[MPIDist::__init__] Repurposing CP ranks to TP for Helix.")
-            mapping_with_cp = copy.deepcopy(self.mapping)
-            self.mapping = self.mapping.repurpose_helix_cp_to_tp()
-
-        self.create_tp_comm()
-        self.create_pp_comm()
-
-        # Restore the original mapping.
-        if mapping_with_cp is not None:
-            logger.info(
-                f"[MPIDist::__init__] Restoring original mapping undoing Helix manipulation."
-            )
-            self.mapping = mapping_with_cp
+        self._cp_comm = None
+        self._tp_comm = None
+        self._pp_comm = None
 
     def broadcast(self, obj, root=0, chunk_size: int = 4 * 1024 * 1024):
         comm = mpi_comm()
@@ -391,6 +436,9 @@ def allgather(self, obj):
     def barrier(self):
         mpi_barrier()
 
+    def tp_barrier(self):
+        self.tp_comm.Barrier()
+
     def isend(self, buf: np.ndarray, dest, tag=0):
         # non-blocking send numpy buffer
         return mpi_isend(buf, dest, tag)
@@ -412,17 +460,32 @@ def isend_object(self, obj, dest, tag=0):
     def recv_object(self, src, tag=0):
         return mpi_recv_object(src, tag)
 
-    def create_tp_comm(self):
-        new_group = mpi_comm().group.Incl(self.mapping.tp_group)
-        self.tp_comm = mpi_comm().Create_group(new_group)
+    @property
+    def tp_comm(self):
+        if self._tp_comm is None:
+            mapping = self.mapping
+            if mapping.has_cp_helix():
+                mapping = mapping.repurpose_helix_cp_to_tp()
+            new_group = mpi_comm().group.Incl(mapping.tp_group)
+            self._tp_comm = mpi_comm().Create_group(new_group)
+        return self._tp_comm
 
-    def create_pp_comm(self):
-        new_group = mpi_comm().group.Incl(self.mapping.pp_group)
-        self.pp_comm = mpi_comm().Create_group(new_group)
+    @property
+    def pp_comm(self):
+        if self._pp_comm is None:
+            mapping = self.mapping
+            if mapping.has_cp_helix():
+                mapping = mapping.repurpose_helix_cp_to_tp()
+            new_group = mpi_comm().group.Incl(mapping.pp_group)
+            self._pp_comm = mpi_comm().Create_group(new_group)
+        return self._pp_comm
 
-    def create_cp_comm(self):
-        new_group = mpi_comm().group.Incl(self.mapping.cp_group)
-        self.cp_comm = mpi_comm().Create_group(new_group)
+    @property
+    def cp_comm(self):
+        if self._cp_comm is None:
+            new_group = mpi_comm().group.Incl(self.mapping.cp_group)
+            self._cp_comm = mpi_comm().Create_group(new_group)
+        return self._cp_comm
 
     def cp_allgather(self, obj):
         return self.cp_comm.allgather(obj)
@@ -459,6 +522,10 @@ def pp_gather(self, obj, root=0):
     def pp_broadcast(self, obj, root=0):
         return self.pp_comm.bcast(obj, root)
 
+    def allreduce(self, obj, op: ReduceOp = ReduceOp.SUM):
+        reduce_op = reduce_op_to_mpi(op)
+        return mpi_comm().allreduce(obj, reduce_op)
+
 
 class MultiHandleWrapper:
     """
@@ -609,6 +676,10 @@ def allgather(self, obj):
     def barrier(self):
         dist.barrier()
 
+    @log_op
+    def tp_barrier(self):
+        dist.barrier(group=self.mapping.tp_group_pg)
+
     @log_op
     def isend(self, buf: np.ndarray, dest, tag=0):
         # non-blocking send numpy buffer
@@ -672,14 +743,16 @@ def isend_object(self, obj, dest, tag=0):
         return MultiHandleWrapper(works)
 
     @log_op
-    def allreduce(self,
-                  obj: int | float | torch.Tensor,
-                  op=torch.distributed.ReduceOp.SUM):
+    def allreduce(
+        self,
+        obj: int | float | torch.Tensor,
+        op: ReduceOp = ReduceOp.SUM,
+    ):
         is_base_type = isinstance(obj, int) or isinstance(obj, float)
         if is_base_type:
             obj = torch.tensor(obj)
 
-        dist.all_reduce(obj, op=op)
+        dist.all_reduce(obj, op=reduce_op_to_torch(op))
 
         if is_base_type:
             obj = obj.item()
diff --git a/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py b/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py
index 852f2e063da..30fdaff3914 100644
--- a/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py
+++ b/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py
@@ -10,7 +10,7 @@
 from tensorrt_llm.mapping import Mapping
 
 from ...inputs.multimodal import MultimodalParams
-from ..distributed import MPIDist
+from ..distributed import Distributed
 from ..expert_statistic import ExpertStatistic
 from ..memory_buffer_utils import get_memory_buffers
 from ..modules.multi_stream_utils import with_multi_stream
@@ -75,7 +75,7 @@ class CUDAGraphRunnerConfig:
     enable_attention_dp: bool
     batch_size: int
     mapping: Optional[Mapping]
-    dist: Optional[MPIDist]
+    dist: Optional[Distributed]
     kv_cache_manager_key: Any
     sparse_attention_config: Optional[BaseSparseAttentionConfig] = None
 
diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py
index f186da6cd89..fdfb511e5e0 100644
--- a/tensorrt_llm/_torch/pyexecutor/model_engine.py
+++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -35,7 +35,7 @@
 from ..autotuner import AutoTuner, autotune
 from ..compilation.backend import Backend
 from ..compilation.utils import capture_piecewise_cuda_graph
-from ..distributed import MPIDist
+from ..distributed import Distributed
 from ..distributed.communicator import init_pp_comm
 from ..expert_statistic import ExpertStatistic
 from ..memory_buffer_utils import with_shared_pool
@@ -134,7 +134,7 @@ def __init__(
         llm_args: TorchLlmArgs,
         mapping: Optional[Mapping] = None,
         attn_runtime_features: Optional[AttentionRuntimeFeatures] = None,
-        dist: Optional[MPIDist] = None,
+        dist: Optional[Distributed] = None,
         spec_config: Optional["DecodingBaseConfig"] = None,
         is_draft_model: bool = False,
         drafting_loop_wrapper: Optional[Callable[[torch.nn.Module],
diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
index bd1857dda27..93cf379235f 100644
--- a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
+++ b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
@@ -13,7 +13,7 @@
 
 import tensorrt_llm
 from tensorrt_llm._torch.pyexecutor.resource_manager import ResourceManagerType
-from tensorrt_llm._utils import get_sm_version, mpi_disabled
+from tensorrt_llm._utils import get_sm_version
 from tensorrt_llm.llmapi.llm_args import (CapacitySchedulerPolicy,
                                           ContextChunkingPolicy,
                                           GuidedDecodingConfig, LoadFormat,
@@ -27,7 +27,7 @@
 
 from ..attention_backend.interface import AttentionRuntimeFeatures
 from ..attention_backend.trtllm import TrtllmAttention
-from ..distributed import MPIDist, TorchDist
+from ..distributed import Distributed
 from ..speculative import (get_num_extra_kv_tokens, get_spec_drafter,
                            get_spec_resource_manager)
 from ..virtual_memory import ExecutorMemoryType, RestoreMode
@@ -303,10 +303,7 @@ def create_py_executor(
                 "when only processing vision encoder inputs.")
 
     mapping = _get_mapping(llm_args.parallel_config.to_mapping())
-    if mpi_disabled():
-        dist = TorchDist(mapping=mapping)
-    else:
-        dist = MPIDist(mapping=mapping)
+    dist = Distributed.get(mapping)
 
     vm_pools = {}
     enable_sleep = llm_args.enable_sleep
diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py
index 619f8525c17..278af63b233 100644
--- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py
+++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py
@@ -10,8 +10,7 @@
 
 import tensorrt_llm
 import tensorrt_llm.bindings
-from tensorrt_llm._utils import mpi_disabled
-from tensorrt_llm.bindings.BuildInfo import ENABLE_MULTI_DEVICE
+from tensorrt_llm._torch.distributed.communicator import Distributed, ReduceOp
 from tensorrt_llm.llmapi.llm_args import (KvCacheConfig, PeftCacheConfig,
                                           PybindMirror)
 from tensorrt_llm.lora_helper import LoraConfig
@@ -28,11 +27,6 @@
                           get_draft_token_length)
 from .scheduler import ScheduledRequests
 
-if ENABLE_MULTI_DEVICE:
-    from mpi4py import MPI
-
-    from tensorrt_llm._utils import mpi_comm
-
 BufferManagerCpp = tensorrt_llm.bindings.internal.runtime.BufferManager
 KVCacheManagerCpp = tensorrt_llm.bindings.internal.batch_manager.KVCacheManager
 CacheTypeCpp = tensorrt_llm.bindings.internal.batch_manager.CacheType
@@ -803,12 +797,11 @@ def calculate_max_num_blocks(self,
 
         if mapping.world_size > 1:
             # make sure all ranks use same value for maxTokens
-            if mpi_disabled():
-                from tensorrt_llm._utils import torch_comm
-                max_tokens = torch_comm().allreduce(
-                    max_tokens, op=torch.distributed.ReduceOp.MIN)
-            else:
-                max_tokens = mpi_comm().allreduce(max_tokens, op=MPI.MIN)
+            dist = Distributed.get(mapping)
+            max_tokens = dist.allreduce(
+                max_tokens,
+                op=ReduceOp.MIN,
+            )
 
         # get number of blocks
         blocks_in_primary_pool = int(max_tokens // tokens_per_block)
diff --git a/tests/microbenchmarks/all_reduce.py b/tests/microbenchmarks/all_reduce.py
index d2a9adf453c..9f50b66d3b4 100644
--- a/tests/microbenchmarks/all_reduce.py
+++ b/tests/microbenchmarks/all_reduce.py
@@ -29,10 +29,10 @@
 from tensorrt_llm import Mapping
 from tensorrt_llm._torch.autotuner import AutoTuner, autotune
 from tensorrt_llm._torch.distributed import (AllReduce, AllReduceFusionOp,
-                                             MPIDist, TorchDist)
+                                             Distributed)
 from tensorrt_llm._torch.modules.rms_norm import RMSNorm
 from tensorrt_llm._utils import (get_sm_version, local_mpi_rank, local_mpi_size,
-                                 mpi_disabled, nvtx_range)
+                                 nvtx_range)
 from tensorrt_llm.bindings.internal.runtime import delay_kernel
 from tensorrt_llm.functional import AllReduceParams, AllReduceStrategy
 from tensorrt_llm.logger import logger
@@ -41,7 +41,7 @@
 
 def profile_allreduce(
     mapping: Mapping,
-    dist: TorchDist | MPIDist,
+    dist: Distributed,
     enable_cudagraph: bool = False,
     inner_loop=200,
     outer_loop=10,
@@ -137,14 +137,10 @@ def allreduce_benchmark(
     cudart.cudaSetDevice(local_rank)
 
     mapping = Mapping(world_size, rank, gpus_per_node, tp_size=world_size)
-    if mpi_disabled():
-        dist = TorchDist(mapping=mapping)
-    else:
-        dist = MPIDist(mapping=mapping)
 
     logger.set_rank(mapping.rank)
 
-    AutoTuner.get().setup_distributed_state(mapping, dist)
+    AutoTuner.get().setup_distributed_state(mapping)
 
     sm_version = get_sm_version()
 
diff --git a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_allreduce_residual_rmsnorm_fusion.py b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_allreduce_residual_rmsnorm_fusion.py
index f9595cde7f0..669833153af 100644
--- a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_allreduce_residual_rmsnorm_fusion.py
+++ b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_allreduce_residual_rmsnorm_fusion.py
@@ -6,10 +6,11 @@
 from torch.export import export
 
 from tensorrt_llm._torch.auto_deploy.custom_ops.trtllm_dist import is_trtllm_op_available
-from tensorrt_llm._torch.auto_deploy.distributed import common as dist
+from tensorrt_llm._torch.auto_deploy.distributed.common import initialize_or_skip
 from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
 from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
 from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op
+from tensorrt_llm._utils import get_free_port
 from tensorrt_llm.llmapi.mpi_session import MpiPoolSession
 
 # needed since MPI executor pool leaks a thread (_manager_spawn) on shutdown
@@ -66,7 +67,7 @@ def _test_allreduce_fusion(port: int, ModuleCls, strategy: str):
     if not is_trtllm_op_available():
         pytest.skip("Require trtllm ops to run test_allreduce_fusion.")
 
-    _, _ = dist.initialize_or_skip(port=port)
+    _, _ = initialize_or_skip(port=port)
 
     # Testing tensors
     dtype = torch.float16
@@ -143,7 +144,7 @@ def test_allreduce_fusion(device_count, ModuleCls, strategy):
     if device_count <= 1:
         pytest.skip("Require multi GPUs to run test_allreduce_fusion.")
 
-    port = dist.get_free_port()
+    port = get_free_port()
 
     n_workers = device_count
     mpi_pool = MpiPoolSession(n_workers=n_workers)
diff --git a/tests/unittest/_torch/misc/test_autotuner.py b/tests/unittest/_torch/misc/test_autotuner.py
index 11b5ce114cd..7212f491e6b 100644
--- a/tests/unittest/_torch/misc/test_autotuner.py
+++ b/tests/unittest/_torch/misc/test_autotuner.py
@@ -17,10 +17,9 @@
                                            FakeTensor, OptimizationProfile,
                                            StaticDim, TunableRunner, TuningConfig,
                                            autotune)
-from tensorrt_llm._torch.distributed.communicator import MPIDist, TorchDist
+from tensorrt_llm._torch.distributed import Distributed
 from tensorrt_llm._torch.utils import (get_power_of_2_num_tokens_buckets,
                                        next_positive_power_of_2)
-from tensorrt_llm._utils import mpi_disabled
 from tensorrt_llm.bindings.internal.runtime import delay_kernel
 from tensorrt_llm.logger import logger
 from tensorrt_llm.mapping import Mapping
@@ -720,14 +719,11 @@ def _distributed_worker_function(world_size, strategy):
                      rank=rank,
                      tp_size=world_size,
                      pp_size=1)
-    if mpi_disabled():
-        dist = TorchDist(mapping=mapping)
-    else:
-        dist = MPIDist(mapping=mapping)
+    dist = Distributed.get(mapping)
 
     tuner = AutoTuner.get()
     tuner.clear_cache()
-    tuner.setup_distributed_state(mapping, dist)
+    tuner.setup_distributed_state(mapping)
 
     x = torch.randn(16, 32, device='cuda')
     w = torch.randn(32, 64, device='cuda')
diff --git a/tests/unittest/_torch/modules/test_fused_moe.py b/tests/unittest/_torch/modules/test_fused_moe.py
index 29be45b8fd2..773218a40c2 100644
--- a/tests/unittest/_torch/modules/test_fused_moe.py
+++ b/tests/unittest/_torch/modules/test_fused_moe.py
@@ -24,7 +24,6 @@
                         skip_pre_hopper)
 
 from tensorrt_llm._torch.autotuner import AutoTuner, autotune
-from tensorrt_llm._torch.distributed import MPIDist, TorchDist
 from tensorrt_llm._torch.model_config import ModelConfig
 from tensorrt_llm._torch.modules.fused_moe.fused_moe_cute_dsl import \
     CuteDslFusedMoE
@@ -45,7 +44,7 @@
 from tensorrt_llm._torch.modules.fused_moe.fused_moe_triton import \
     IS_TRITON_KERNELS_AVAILABLE
 from tensorrt_llm._torch.modules.gated_mlp import GatedMLP
-from tensorrt_llm._utils import get_sm_version, mpi_disabled, mpi_rank
+from tensorrt_llm._utils import get_sm_version, mpi_rank
 from tensorrt_llm.mapping import Mapping
 from tensorrt_llm.models.modeling_utils import QuantAlgo, QuantConfig
 
@@ -105,12 +104,7 @@ def test_fused_moe(moe_backend,
 
     mapping = mapping or Mapping()
     mapping.rank = mpi_rank()
-    if mpi_disabled():
-        dist = TorchDist(mapping=mapping)
-    else:
-        dist = MPIDist(mapping=mapping)
-
-    AutoTuner.get().setup_distributed_state(mapping, dist)
+    AutoTuner.get().setup_distributed_state(mapping)
 
     torch.cuda.set_device(mapping.rank)
 
diff --git a/tests/unittest/_torch/multi_gpu/test_allreduce.py b/tests/unittest/_torch/multi_gpu/test_allreduce.py
index b531d826435..2ce869c13b7 100644
--- a/tests/unittest/_torch/multi_gpu/test_allreduce.py
+++ b/tests/unittest/_torch/multi_gpu/test_allreduce.py
@@ -27,11 +27,9 @@
 from tensorrt_llm._torch.autotuner import AutoTuner, autotune
 from tensorrt_llm._torch.distributed import (AllReduce, AllReduceFusionOp,
                                              AllReduceParams, AllReduceStrategy,
-                                             MoEAllReduce, MoEAllReduceParams,
-                                             MPIDist, TorchDist)
+                                             MoEAllReduce, MoEAllReduceParams)
 from tensorrt_llm._torch.modules.linear import Linear, TensorParallelMode
 from tensorrt_llm._torch.modules.rms_norm import RMSNorm
-from tensorrt_llm._utils import mpi_disabled
 from tensorrt_llm.mapping import Mapping
 
 sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
@@ -133,12 +131,8 @@ def e2m1_and_ufp8sf_scale_to_float_v2(e2m1_tensor,
         tp_size=tensor_parallel_size,
         rank=tensor_parallel_rank,
     )
-    if mpi_disabled():
-        dist = TorchDist(mapping=mapping)
-    else:
-        dist = MPIDist(mapping=mapping)
-    AutoTuner.get().setup_distributed_state(mapping, dist)
+    AutoTuner.get().setup_distributed_state(mapping)
 
     linear = Linear(
         in_features=hidden_size,
         out_features=hidden_size,
diff --git a/tests/unittest/others/test_kv_cache_transceiver.py b/tests/unittest/others/test_kv_cache_transceiver.py
index a8afc3f2cfb..7638f9c8057 100644
--- a/tests/unittest/others/test_kv_cache_transceiver.py
+++ b/tests/unittest/others/test_kv_cache_transceiver.py
@@ -6,7 +6,7 @@
 import tensorrt_llm
 import tensorrt_llm.bindings
 import tensorrt_llm.bindings.executor as trtllm
-from tensorrt_llm._torch.distributed import MPIDist
+from tensorrt_llm._torch.distributed import Distributed
 from tensorrt_llm._torch.pyexecutor.kv_cache_transceiver import \
     create_kv_cache_transceiver
 from tensorrt_llm._torch.pyexecutor.llm_request import (LlmRequest,
@@ -79,7 +79,7 @@ def test_kv_cache_transceiver_single_process(ctx_gen_kv_cache_dtype,
     cache_transceiver_config = CacheTransceiverConfig(backend=backend,
                                                       max_tokens_in_buffer=512)
 
-    dist = MPIDist(mapping=mapping)
+    dist = Distributed.get(mapping)
     kv_cache_transceiver_ctx = create_kv_cache_transceiver(
         mapping, dist, kv_cache_manager_ctx, attention_type,
         cache_transceiver_config)
@@ -139,7 +139,7 @@ def test_kv_cache_transceiver_single_process(ctx_gen_kv_cache_dtype,
 def test_cancel_request_in_transmission(attention_type):
     # Init kv_cache manager and cache transceiver
     mapping = Mapping(world_size=1, rank=0)
-    dist = MPIDist(mapping=mapping)
+    dist = Distributed.get(mapping)
     ctx_kv_cache_dtype, gen_kv_cache_dtype = DataType.HALF, DataType.HALF
     kv_cache_manager_ctx = create_kv_cache_manager(mapping, ctx_kv_cache_dtype)
     kv_cache_manager_gen = create_kv_cache_manager(mapping, gen_kv_cache_dtype)
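
Usage sketch (not part of the patch): the snippet below only illustrates how the unified Distributed.get() factory, the backend-neutral ReduceOp enum, and the simplified AutoTuner.get().setup_distributed_state(mapping) call introduced by this diff fit together; the single-rank Mapping and the literal token count are illustrative assumptions, not code from the repository.

# Illustrative example only; assumes a single-rank Mapping for simplicity.
from tensorrt_llm._torch.autotuner import AutoTuner
from tensorrt_llm._torch.distributed.communicator import Distributed, ReduceOp
from tensorrt_llm.mapping import Mapping

mapping = Mapping(world_size=1, rank=0, tp_size=1)

# Distributed.get() returns TorchDist when MPI is disabled and MPIDist otherwise,
# caching one instance per mapping, so call sites no longer branch on mpi_disabled().
dist = Distributed.get(mapping)

# Collectives now take the backend-neutral ReduceOp enum defined in communicator.py.
max_tokens = dist.allreduce(4096, op=ReduceOp.MIN)

# tp_barrier() is the new unified TP barrier used by the autotuner's MERGE strategy.
dist.tp_barrier()

# The AutoTuner builds its own communicator from the mapping when none is passed.
AutoTuner.get().setup_distributed_state(mapping)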