diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/README.md b/tensorrt_llm/_torch/auto_deploy/custom_ops/README.md
index 6bef175199b..186123ba0fe 100644
--- a/tensorrt_llm/_torch/auto_deploy/custom_ops/README.md
+++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/README.md
@@ -18,8 +18,8 @@ The table below lists the operators ordered by their backend.
 | `torch.ops.auto_deploy.torch_attention_grouped_sdpa` | Grouped SDPA implementation |
 | `torch.ops.auto_deploy.torch_attention_repeat_kv` | KV repetition for attention |
 | `torch.ops.auto_deploy.torch_attention_sdpa` | Standard SDPA implementation |
-| `torch.ops.auto_deploy.torch_dist_all_gather` | Distributed all-gather operation |
-| `torch.ops.auto_deploy.torch_dist_all_reduce` | Distributed all-reduce operation |
+| `torch.ops.auto_deploy.torch_all_gather` | Distributed all-gather operation |
+| `torch.ops.auto_deploy.torch_all_reduce` | Distributed all-reduce operation |
 | `torch.ops.auto_deploy.torch_linear_simple` | Simple linear layer implementation |
 | `torch.ops.auto_deploy.torch_moe` | Mixture of Experts implementation |
 | `torch.ops.auto_deploy.torch_moe_fused` | Fused Mixture of Experts implementation |
@@ -39,4 +39,4 @@ The table below lists the operators ordered by their backend.
 | `torch.ops.auto_deploy.triton_rope_on_flattened_inputs` | Triton RoPE on flattened inputs |
 | `torch.ops.auto_deploy.triton_rope_with_input_pos` | Triton RoPE with input positions |
 | `torch.ops.auto_deploy.trtllm_moe_fused` | TensorRT-LLM fused MoE implementation |
-| `torch.ops.auto_deploy.trtllm_dist_fused_linear_all_reduce` | TensorRT-LLM fused linear layer followed by all-reduce operation |
+| `torch.ops.auto_deploy.torch_fused_linear_all_reduce` | Fused linear layer followed by all-reduce operation |
diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/__init__.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/__init__.py
index 23a80b94d74..2fb867d8c79 100644
--- a/tensorrt_llm/_torch/auto_deploy/custom_ops/__init__.py
+++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/__init__.py
@@ -1,17 +1,16 @@
-"""Custom ops and make sure they are all registered."""
+"""AutoDeploy's custom ops library.
 
-from ._triton_attention_internal import *
-from .dist import *
-from .flashinfer_attention import *
-from .flashinfer_rope import *
-from .linear import *
-from .mla import *
-from .quant import *
-from .rms_norm import *
-from .torch_attention import *
-from .torch_backend_attention import *
-from .torch_moe import *
-from .torch_rope import *
-from .triton_attention import *
-from .triton_rope import *
-from .trtllm_moe import *
+This file ensures that all public modules in the custom_ops folder are auto-imported and the
+corresponding custom ops are registered.
+""" + +import importlib +import pkgutil + +__all__ = [] + +for _, module_name, is_pkg in pkgutil.iter_modules(__path__): + if module_name.startswith("_"): + continue + __all__.append(module_name) + importlib.import_module(f"{__name__}.{module_name}") diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/linear.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/linear.py index fda48e4ba57..214626ad24e 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/linear.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/linear.py @@ -4,9 +4,6 @@ import torch -from ..distributed import common as dist -from ..distributed import trtllm as trtllm_dist - @torch.library.custom_op("auto_deploy::torch_linear_simple", mutates_args=()) def simple(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor: @@ -24,26 +21,4 @@ def simple(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tenso @simple.register_fake def simple_fake(input, weight, bias): """Fake implementation of simple_linear.""" - # return torch.empty( - # input.shape[:-1] + (weight.shape[-1],), dtype=input.dtype, device=input.device - # ) - return torch.ops.aten.linear(input, weight, bias) - - -@torch.library.custom_op( - "auto_deploy::trtllm_dist_fused_linear_all_reduce", mutates_args=(), device_types="cuda" -) -def fused_linear_all_reduce( - input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] -) -> torch.Tensor: - """Fused linear followed by all_reduce on the output.""" - output = torch.ops.aten.linear(input, weight, bias) - if trtllm_dist.is_trtllm_op_available(): - return trtllm_dist.trtllm_allreduce(output, op=dist.ReduceOp.SUM) - dist.all_reduce(output, op=dist.ReduceOp.SUM) - return output - - -@fused_linear_all_reduce.register_fake -def fused_linear_all_reduce_fake(input, weight, bias): return torch.ops.aten.linear(input, weight, bias) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py index d17b816e825..6c322ae8eba 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py @@ -10,7 +10,6 @@ from tensorrt_llm._torch.autotuner import autotune from ..distributed import common as dist -from ..distributed import trtllm as trtllm_dist from .torch_libs.float8_python_api import addmm_float8_unwrapped TRTLLM_FP4_OP_AVAILABLE = True @@ -119,8 +118,6 @@ def fused_fp8_linear_all_reduce( out = torch.ops.auto_deploy.torch_quant_fp8_linear( input, weight_fp8, bias, input_scale, weight_scale ) - if trtllm_dist.is_trtllm_op_available(): - return trtllm_dist.trtllm_allreduce(out, op=dist.ReduceOp.SUM) dist.all_reduce(out, op=dist.ReduceOp.SUM) return out diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/dist.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_dist.py similarity index 59% rename from tensorrt_llm/_torch/auto_deploy/custom_ops/dist.py rename to tensorrt_llm/_torch/auto_deploy/custom_ops/torch_dist.py index d6f13fbedd7..faf38e6c984 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/dist.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_dist.py @@ -5,16 +5,13 @@ import torch from ..distributed import common as dist -from ..distributed import trtllm as trtllm_dist -@torch.library.custom_op("auto_deploy::torch_dist_all_gather", mutates_args=(), device_types="cuda") +@torch.library.custom_op("auto_deploy::torch_all_gather", mutates_args=(), device_types="cuda") def all_gather( tensor: torch.Tensor, dim: int = 0, sizes: 
Optional[List[int]] = None ) -> torch.Tensor: """All gather followed by concat in dim = 0. This is the default nccl behavior.""" - if trtllm_dist.is_trtllm_op_available(): - return trtllm_dist.trtllm_allgather(tensor, dim=dim, sizes=sizes) tl = [torch.zeros_like(tensor) for _ in range(dist.get_world_size())] dist.all_gather(tl, tensor) return torch.cat(tl, dim=dim) @@ -25,15 +22,13 @@ def all_gather_fake(tensor, dim=0): return torch.cat([torch.empty_like(tensor) for _ in range(dist.get_world_size())], dim=dim) -@torch.library.custom_op("auto_deploy::torch_dist_all_reduce", mutates_args=(), device_types="cuda") +@torch.library.custom_op("auto_deploy::torch_all_reduce", mutates_args=(), device_types="cuda") def all_reduce(t: torch.Tensor) -> torch.Tensor: """All_reduce across the ranks. Reduction op is SUM. NOTE: this op requires an extra memory copy and should ONLY be used for debugging + testing. For efficient all_reduce ops one should write/replace it with a fused op. """ - if trtllm_dist.is_trtllm_op_available(): - return trtllm_dist.trtllm_allreduce(t, op=dist.ReduceOp.SUM) t_res = t.clone() dist.all_reduce(t_res, op=dist.ReduceOp.SUM) return t_res @@ -42,3 +37,20 @@ def all_reduce(t: torch.Tensor) -> torch.Tensor: @all_reduce.register_fake def all_reduce_fake(tensor): return torch.empty_like(tensor) + + +@torch.library.custom_op( + "auto_deploy::torch_fused_linear_all_reduce", mutates_args=(), device_types="cuda" +) +def fused_linear_all_reduce( + input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] +) -> torch.Tensor: + """Fused linear followed by in-place all_reduce on the output.""" + output = torch.ops.aten.linear(input, weight, bias) + dist.all_reduce(output, op=dist.ReduceOp.SUM) + return output + + +@fused_linear_all_reduce.register_fake +def fused_linear_all_reduce_fake(input, weight, bias): + return torch.ops.aten.linear(input, weight, bias) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/_triton_attention_internal.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/triton_attention_internal.py similarity index 100% rename from tensorrt_llm/_torch/auto_deploy/custom_ops/_triton_attention_internal.py rename to tensorrt_llm/_torch/auto_deploy/custom_ops/triton_attention_internal.py diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/trtllm_dist.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/trtllm_dist.py new file mode 100644 index 00000000000..870fe92520c --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/trtllm_dist.py @@ -0,0 +1,85 @@ +"""TRT-LLM optimized dist ops.""" + +from typing import List, Optional + +import torch + +from ....functional import AllReduceParams +from ....mapping import Mapping +from ...distributed import AllReduce, allgather +from ...modules.linear import AllReduceFusionOp, AllReduceStrategy +from ..distributed.common import get_rank_world_size + + +@torch.library.custom_op("auto_deploy::trtllm_all_gather", mutates_args=(), device_types="cuda") +def all_gather( + tensor: torch.Tensor, dim: int = 0, sizes: Optional[List[int]] = None +) -> torch.Tensor: + """TRT-LLM all gather followed by concat in dim = 0.""" + rank, world_size = get_rank_world_size() + p_config = Mapping(world_size=world_size, tp_size=world_size, rank=rank) + result = allgather(tensor, p_config, dim=dim, sizes=sizes) + assert isinstance(result, torch.Tensor), "Expected tensor result from allgather" + return result + + +@all_gather.register_fake +def all_gather_fake(tensor, dim=0, sizes=None): + rank, world_size = get_rank_world_size() + return 
torch.cat([torch.empty_like(tensor) for _ in range(world_size)], dim=dim)
+
+
+@torch.library.custom_op("auto_deploy::trtllm_all_reduce", mutates_args=(), device_types="cuda")
+def all_reduce(tensor: torch.Tensor, strategy: int = int(AllReduceStrategy.AUTO)) -> torch.Tensor:
+    """TRT-LLM all_reduce across the ranks."""
+    rank, world_size = get_rank_world_size()
+    p_config = Mapping(world_size=world_size, tp_size=world_size, rank=rank)
+    torch_op = AllReduce(mapping=p_config, strategy=AllReduceStrategy(strategy))
+    result = torch_op(tensor)
+    assert isinstance(result, torch.Tensor), "Expected tensor result from allreduce"
+    return result
+
+
+@all_reduce.register_fake
+def all_reduce_fake(tensor):
+    return torch.empty_like(tensor)
+
+
+@torch.library.custom_op(
+    "auto_deploy::trtllm_fused_allreduce_residual_rmsnorm", mutates_args=(), device_types="cuda"
+)
+def trtllm_fused_allreduce_residual_rmsnorm(
+    tensor: torch.Tensor,
+    residual: torch.Tensor,
+    norm_weight: torch.Tensor,
+    eps: float,
+    strategy: int = int(AllReduceStrategy.AUTO),
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """Fusing allreduce, residual (add), and hf_rms_norm together."""
+    rank, world_size = get_rank_world_size()
+    p_config = Mapping(world_size=world_size, tp_size=world_size, rank=rank)
+
+    # Use AllReduceParams like the old implementation
+    all_reduce_params = AllReduceParams(
+        fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
+        bias=None,
+        residual=residual,
+        norm_weight=norm_weight,
+        eps=eps,
+    )
+
+    torch_op = AllReduce(mapping=p_config, strategy=AllReduceStrategy(strategy))
+    output = torch_op(tensor, all_reduce_params=all_reduce_params)
+    assert len(output) == 2, "Expected 2 outputs from trtllm_fused_allreduce_residual_rmsnorm"
+    return output[0], output[1]
+
+
+@trtllm_fused_allreduce_residual_rmsnorm.register_fake
+def trtllm_fused_allreduce_residual_rmsnorm_fake(
+    tensor: torch.Tensor,
+    residual: torch.Tensor,
+    norm_weight: torch.Tensor,
+    eps: float,
+    strategy: int = int(AllReduceStrategy.AUTO),
+) -> tuple[torch.Tensor, torch.Tensor]:
+    return torch.empty_like(tensor), torch.empty_like(tensor)
diff --git a/tensorrt_llm/_torch/auto_deploy/distributed/__init__.py b/tensorrt_llm/_torch/auto_deploy/distributed/__init__.py
index e69de29bb2d..55e5f844b4a 100644
--- a/tensorrt_llm/_torch/auto_deploy/distributed/__init__.py
+++ b/tensorrt_llm/_torch/auto_deploy/distributed/__init__.py
@@ -0,0 +1 @@
+from .common import *
diff --git a/tensorrt_llm/_torch/auto_deploy/distributed/common.py b/tensorrt_llm/_torch/auto_deploy/distributed/common.py
index 2a976280e7a..1fe49f7c6ac 100644
--- a/tensorrt_llm/_torch/auto_deploy/distributed/common.py
+++ b/tensorrt_llm/_torch/auto_deploy/distributed/common.py
@@ -1,4 +1,15 @@
-"""Common utilities for distributed inference."""
+"""Common utilities for distributed inference.
+
+These utilities serve two purposes:
+
+1. Drop-in replacement for torch.distributed to ensure that any function in torch.distributed works
+   out of the box.
+2. Provide a simple interface to spawn multiple processes and communicate with them. We support
+   three modes:
+   a. MPI initialized via TRT-LLM runtime or mpirun.
+   b. TorchElastic initialized via torchrun.
+   c. Simple Python multiprocessing started explicitly.
+""" import atexit import os @@ -12,8 +23,6 @@ from ..utils.logger import ad_logger -# TODO: check to what extend we can reuse _torch/distributed.py - class _DistGroup: """Global instance to set/get the default process group for distributed ops.""" @@ -93,17 +102,22 @@ def initialize_or_skip(*args, **kwargs) -> Tuple[int, int]: return get_rank(), get_world_size() -def is_ompi(): +def is_ompi() -> bool: """Check whether multi-processing was initialized with explicitly calling mpirun.""" return "OMPI_COMM_WORLD_SIZE" in os.environ -def is_torchelastic(): +def is_trtllm_dist_available() -> bool: + """Check if TRT-LLM distributed operations are available (they only work with MPI!).""" + return is_ompi() + + +def is_torchelastic() -> bool: """Check whether multi-processing was initialized with torchelastic.""" return "TORCHELASTIC_RUN_ID" in os.environ -def cleanup(): +def cleanup() -> None: """Destroy process group when the program exits.""" if dist.is_initialized(): ad_logger.info("Destroying process group") diff --git a/tensorrt_llm/_torch/auto_deploy/distributed/trtllm.py b/tensorrt_llm/_torch/auto_deploy/distributed/trtllm.py deleted file mode 100644 index e42da002f6d..00000000000 --- a/tensorrt_llm/_torch/auto_deploy/distributed/trtllm.py +++ /dev/null @@ -1,59 +0,0 @@ -import torch - -from .common import ReduceOp, get_rank_world_size, is_ompi - -# use trtllm distributed ops to improve TP performance if possible -try: - from ....mapping import Mapping - from ...distributed import AllReduce, allgather - from ...modules.linear import AllReduceFusionOp, AllReduceParams, AllReduceStrategy - - def trtllm_allgather(tensor, dim, sizes=None): - rank, world_size = get_rank_world_size() - p_config = Mapping(world_size=world_size, tp_size=world_size, rank=rank) - return allgather(tensor, p_config, dim=dim, sizes=sizes) - - def trtllm_allreduce(tensor, op, all_reduce_params=None): - rank, world_size = get_rank_world_size() - assert op == ReduceOp.SUM, "TRT-LLM all reduce only supports SUM op." 
-        p_config = Mapping(world_size=world_size, tp_size=world_size, rank=rank)
-        torch_op = AllReduce(mapping=p_config, strategy=AllReduceStrategy.AUTO)
-        return torch_op(tensor, all_reduce_params=all_reduce_params)
-
-    @torch.library.custom_op(
-        "dist::fused_allreduce_residual_rmsnorm", mutates_args=(), device_types="cuda"
-    )
-    def fused_allreduce_residual_rmsnorm(
-        tensor: torch.Tensor, residual: torch.Tensor, norm_weight: torch.Tensor, eps: float
-    ) -> tuple[torch.Tensor, torch.Tensor]:
-        """Fusing allreduce, residual (add), and hf_rms_norm together."""
-        all_reduce_params = AllReduceParams(
-            fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
-            bias=None,
-            residual=residual,
-            norm_weight=norm_weight,
-            eps=eps,
-        )
-        return trtllm_allreduce(tensor, ReduceOp.SUM, all_reduce_params=all_reduce_params)
-
-    @fused_allreduce_residual_rmsnorm.register_fake
-    def fused_allreduce_residual_rmsnorm_fake(
-        tensor: torch.Tensor, residual: torch.Tensor, norm_weight: torch.Tensor, eps: float
-    ) -> tuple[torch.Tensor, torch.Tensor]:
-        return torch.empty_like(tensor), torch.empty_like(tensor)
-
-    TRTLLM_OP_AVAILABLE = True
-except ImportError:
-
-    def trtllm_allgather(tensor, dim, sizes=None):
-        raise ImportError("TRT-LLM is not available.")
-
-    def trtllm_allreduce(tensor, op):
-        raise ImportError("TRT-LLM is not available.")
-
-    TRTLLM_OP_AVAILABLE = False
-
-
-def is_trtllm_op_available():
-    # TRT-LLM only work with MPI
-    return TRTLLM_OP_AVAILABLE and is_ompi()
diff --git a/tensorrt_llm/_torch/auto_deploy/transformations/library/collectives.py b/tensorrt_llm/_torch/auto_deploy/transformations/library/collectives.py
index 8cec047561f..b5319ce37d1 100644
--- a/tensorrt_llm/_torch/auto_deploy/transformations/library/collectives.py
+++ b/tensorrt_llm/_torch/auto_deploy/transformations/library/collectives.py
@@ -3,33 +3,38 @@
 import torch
 from torch.fx import GraphModule
 
-from ...distributed.trtllm import is_trtllm_op_available
 from ...utils.logger import ad_logger
 from ...utils.node_utils import get_op_overload_packet, get_user_if_pattern_match, is_op
 from .._graph import canonicalize_graph
 
 
-# TODO: This is an overly simplified model that works well for vanilla Llama models.
-# However, we eventually want to consider more sophisticated patterns such as
-# * all_reduce(lin1(x) + lin2(x))
-# * version above with fused GEMMs (i.e. with a split node)
-# * all_reduce(pointwise_op(linear(x)))
-# * ...
-def fuse_collectives(gm: GraphModule) -> None:
+# TODO (lucaslie): reconsider the value-add of this transform. Here are some thoughts:
+# 1. This is only useful for using torch's built-in, in-place allreduce op. TRT-LLM's allreduce op
+#    is not in-place anyway (hence can be expressed as a custom op) and supposedly more performant.
+# 2. Peak performance can only be achieved by fusing the collective with other ops like RMS or
+#    similar pointwise ops. None of this can be done with the stock torch collectives.
+# 3. Hence I wonder if there is any value in a transform that only gives us a small perf boost but
+#    does not really take us to peak perf anyway...
+def fuse_torch_allreduce(gm: GraphModule) -> None:
+    """Replace linear followed by torch's allreduce with a fused op.
+
+    Torch's collectives are in-place operations, which we cannot express as a single custom op.
+    Hence, to take advantage of the memory savings, we fuse them with the preceding linear op.
+ """ num_gemm_collective_fusions = 0 ad_logger.debug("Before GEMM+Collective fusion: " + str(gm)) # lookup for fused ops # TODO: avoid this hardcoded lookup, e.g., by generating fused ops on the fly. lookup = { - torch.ops.auto_deploy.torch_linear_simple: torch.ops.auto_deploy.trtllm_dist_fused_linear_all_reduce, - torch.ops.aten.linear: torch.ops.auto_deploy.trtllm_dist_fused_linear_all_reduce, + torch.ops.auto_deploy.torch_linear_simple: torch.ops.auto_deploy.torch_fused_linear_all_reduce, + torch.ops.aten.linear: torch.ops.auto_deploy.torch_fused_linear_all_reduce, torch.ops.auto_deploy.torch_quant_fp8_linear: torch.ops.auto_deploy.torch_quant_fused_fp8_linear_all_reduce, } # go through all nodes and find all_reduce nodes for node in gm.graph.nodes: - if not is_op(node, torch.ops.auto_deploy.torch_dist_all_reduce): + if not is_op(node, torch.ops.auto_deploy.torch_all_reduce): continue # check if args are as expected @@ -63,16 +68,13 @@ def fuse_allreduce_residual_rmsnorm(gm: GraphModule) -> None: """Essentially, this function fuses the following operators into one allreduce trtllm implementation. * target pattern: - x = all_reduce(x) + x = all_reduce(x, strategy) y = x + residual return rmsnorm(y), y * replacement: - fused_allreduce_residual_rmsnorm(x, residual, rmsnorm_weight, rmsnorm_eps) + trtllm_fused_allreduce_residual_rmsnorm(x, residual, rmsnorm_weight, rmsnorm_eps, strategy) """ - if not is_trtllm_op_available(): - return - num_ar_r_rms_fusions = 0 ad_logger.debug("Before allreduce+residual+rmsnorm fusion: " + str(gm)) @@ -125,16 +127,15 @@ def trace_and_fuse(allreduce_node, graph): ) eps = add_eps_node.args[1] + all_args = (tensor, residual, norm_weight, eps) + if len(allreduce_node.args) > 1: + all_args += (allreduce_node.args[1],) # strategy is the second arg + # Insert nodes with graph.inserting_before(allreduce_node): fused_node = graph.call_function( - torch.ops.dist.fused_allreduce_residual_rmsnorm, - args=( - tensor, - residual, - norm_weight, - eps, - ), + torch.ops.auto_deploy.trtllm_fused_allreduce_residual_rmsnorm, + args=all_args, ) # Extract outputs from the tuple returned by `fused_node` final_output_node = gm.graph.create_node( @@ -159,7 +160,7 @@ def trace_and_fuse(allreduce_node, graph): # Traverse all nodes for node in gm.graph.nodes: - if is_op(node, torch.ops.auto_deploy.torch_dist_all_reduce): + if is_op(node, torch.ops.auto_deploy.trtllm_all_reduce): trace_and_fuse(allreduce_node=node, graph=gm.graph) canonicalize_graph(gm) diff --git a/tensorrt_llm/_torch/auto_deploy/transformations/library/sharding.py b/tensorrt_llm/_torch/auto_deploy/transformations/library/sharding.py index d7ed5918a49..c7b025957d1 100644 --- a/tensorrt_llm/_torch/auto_deploy/transformations/library/sharding.py +++ b/tensorrt_llm/_torch/auto_deploy/transformations/library/sharding.py @@ -20,15 +20,18 @@ import operator from abc import ABC, abstractmethod from collections import defaultdict -from enum import IntEnum +from enum import IntEnum, StrEnum from functools import partial -from typing import Callable, DefaultDict, Dict, List, Literal, Optional, Set +from typing import Callable, DefaultDict, Dict, List, Literal, Optional, Set, Tuple, overload import torch import torch.nn as nn from pydantic import BaseModel, ConfigDict, Field +from torch._ops import OpOverloadPacket from torch.fx import GraphModule, Node +from .....functional import AllReduceStrategy +from ... 
import distributed as dist from ...utils.logger import ad_logger from ...utils.node_utils import ( extract_param_names_from_lin_node, @@ -48,6 +51,14 @@ class SplitDimension(IntEnum): COLUMN = 1 # Split along columns (second dimension) +class DistBackend(StrEnum): + """Enum for distributed backends.""" + + AUTO = "auto" # pick trtllm if available, otherwise torch + TORCH = "torch" + TRTLLM = "trtllm" + + class ShardingTransformInfo(BaseModel, ABC): """Abstract base class for transformation configurations.""" @@ -56,6 +67,8 @@ class ShardingTransformInfo(BaseModel, ABC): target_node: str rank: int world_size: int + dist_backend: DistBackend = DistBackend.AUTO + trtllm_allreduce_strategy: AllReduceStrategy = AllReduceStrategy.AUTO def validate(self, gm: GraphModule = None, node: Node = None) -> bool: """ @@ -79,6 +92,34 @@ def check_and_apply(self, gm: GraphModule, node: Node) -> None: return self.apply(gm, node) + @overload + def _get_dist_op(self, dist_op: None) -> Tuple[None, Tuple]: ... + + @overload + def _get_dist_op(self, dist_op: str) -> Tuple[OpOverloadPacket, Tuple]: ... + + def _get_dist_op(self, dist_op: Optional[str]) -> Tuple[Optional[OpOverloadPacket], Tuple]: + """Get the dist op and extra final args for the dist op.""" + if dist_op is None: + return None, () + # construct dist op lookup + strategy = int(self.trtllm_allreduce_strategy) + dist_lookup = { + DistBackend.TORCH: { + "all_reduce": (torch.ops.auto_deploy.torch_all_reduce, ()), + "all_gather": (torch.ops.auto_deploy.torch_all_gather, ()), + }, + DistBackend.TRTLLM: { + "all_reduce": (torch.ops.auto_deploy.trtllm_all_reduce, (strategy,)), + "all_gather": (torch.ops.auto_deploy.trtllm_all_gather, ()), + }, + } + auto_backend = DistBackend.TRTLLM if dist.is_trtllm_dist_available() else DistBackend.TORCH + dist_lookup[DistBackend.AUTO] = dist_lookup[auto_backend] + + dist_op, final_args = dist_lookup[self.dist_backend][dist_op] + return dist_op, final_args # type: ignore + class TPShardingInfo(ShardingTransformInfo): """Configuration for TP sharding transformations.""" @@ -107,13 +148,18 @@ def validate(self, gm: GraphModule = None, node: Node = None) -> bool: def apply(self, gm: GraphModule, node: Node) -> None: """Apply TP sharding transformation to the graph module.""" + dist_op, dist_final_args = self._get_dist_op(self.dist_op) + if self.dist_op == "all_gather": + dist_final_args = (-1,) + dist_final_args + _insert_sharded_matmul( gm=gm, node=node, dim=self.split_dim.value, rank=self.rank, world_size=self.world_size, - add_dist=self.dist_op is not None, + dist_op=dist_op, + dist_extra_args=dist_final_args, min_local_shape=self.min_local_shape, ) @@ -146,7 +192,7 @@ def validate(self, gm: GraphModule = None, node: Node = None) -> bool: # Check if the distribution is balanced remainder = bmm_batch_size % self.world_size - # NOTE: our torch.ops.auto_deploy.torch_dist_all_gather doesn't support uneven splits at the moment. + # NOTE: our torch.ops.auto_deploy.torch_all_gather doesn't support uneven splits at the moment. if remainder: ad_logger.warning( f"BMM batch size {bmm_batch_size} is not divisible by world size {self.world_size}. 
" @@ -211,10 +257,11 @@ def slice_tensor(t: torch.Tensor) -> torch.Tensor: handle_tensor(node, rhs_tensor, 1, self.start_idx, self.end_idx) # Add all_gather node after BMM to collect results + dist_op, dist_final_args = self._get_dist_op("all_gather") with gm.graph.inserting_after(node): gather_node = gm.graph.call_function( - torch.ops.auto_deploy.torch_dist_all_gather, - args=(node, 0), # Gather along batch dimension (0) + dist_op, + args=(node, 0, *dist_final_args), # Gather along batch dimension (0) ) node.replace_all_uses_with(gather_node) gather_node.replace_input_with(gather_node, node) @@ -242,7 +289,8 @@ def validate(self, gm: GraphModule = None, node: Node = None) -> bool: def apply(self, gm: GraphModule, node: Node) -> None: """Apply EP sharding transformation to the graph module.""" - _insert_sharded_moe(gm, node, self.rank, self.world_size) + dist_op, dist_final_args = self._get_dist_op("all_reduce") + _insert_sharded_moe(gm, node, dist_op, dist_final_args, self.rank, self.world_size) class ShardingConfig(BaseModel): @@ -323,7 +371,8 @@ def _insert_sharded_matmul( dim: int, rank: int, world_size: int, - add_dist: bool = False, + dist_op: Optional[Callable] = None, + dist_extra_args: Optional[Tuple] = None, min_local_shape: int = 1, ) -> None: """Replace the matmul node with a new matmul node that accepts sharded weights. @@ -331,7 +380,7 @@ def _insert_sharded_matmul( The state_dict is also updated to contain the sharded weights. """ assert dim in [0, 1], "Only dim 0 and 1 are supported for sharding" - assert add_dist or dim == 0, "For dim=1 sharding, dist_op is required." + assert dist_op or dim == 0, "For dim=1 sharding, dist_op is required." quantization_impl = QuantizationImpl.create(node) @@ -428,20 +477,13 @@ def set_new_param(submod: nn.Module, param_key: str, remove: bool = False) -> to ) ) - # no comm node needed for single device - if not add_dist: + # no dist op provided + if dist_op is None: return - # figure out the right dist op - dist_lookup = { - 0: (torch.ops.auto_deploy.torch_dist_all_gather, -1), - 1: (torch.ops.auto_deploy.torch_dist_all_reduce,), - } - fn_dist, *dist_args = dist_lookup[dim] - - # add reduction node + # add dist op node with gm.graph.inserting_after(node): - dist_node = gm.graph.call_function(fn_dist, args=(node, *dist_args)) + dist_node = gm.graph.call_function(dist_op, args=(node, *(dist_extra_args or ()))) node.replace_all_uses_with(dist_node) dist_node.replace_input_with(dist_node, node) @@ -690,7 +732,7 @@ def detect_dp_bmm_shard( base_size = bmm_batch_size // world_size remainder = bmm_batch_size % world_size - # NOTE: our torch.ops.auto_deploy.torch_dist_all_gather doesn't support uneven splits at the moment. + # NOTE: our torch.ops.auto_deploy.torch_all_gather doesn't support uneven splits at the moment. if remainder: ad_logger.warning( f"BMM batch size {bmm_batch_size} is not divisible by world size {world_size}. 
" @@ -764,6 +806,8 @@ def detect_ep_shard( def _insert_sharded_moe( gm: GraphModule, node: Node, + dist_op: OpOverloadPacket, + dist_final_args: Tuple, rank: int, world_size: int, ): @@ -841,8 +885,6 @@ def get_partition(lst, world_size, rank): # -- add an all_reduce node -- with gm.graph.inserting_after(node): - dist_node = gm.graph.call_function( - torch.ops.auto_deploy.torch_dist_all_reduce, args=(node,) - ) + dist_node = gm.graph.call_function(dist_op, args=(node, *dist_final_args)) node.replace_all_uses_with(dist_node) dist_node.replace_input_with(dist_node, node) diff --git a/tensorrt_llm/_torch/auto_deploy/transformations/library/visualization.py b/tensorrt_llm/_torch/auto_deploy/transformations/library/visualization.py index aaf77ac8e8c..ab2884991f7 100644 --- a/tensorrt_llm/_torch/auto_deploy/transformations/library/visualization.py +++ b/tensorrt_llm/_torch/auto_deploy/transformations/library/visualization.py @@ -67,10 +67,10 @@ def add_outputs_metadata(self, fx_node: torch.fx.node.Node, node: GraphNode): # TODO(yudong): make custom_ops configurable CUSTOM_OPS = ( - torch.ops.auto_deploy.torch_dist_all_reduce.default, + torch.ops.auto_deploy.torch_all_reduce.default, torch.ops.aten.slice.Tensor, torch.ops.auto_deploy.triton_attention_fused_mha_with_cache.default, - torch.ops.auto_deploy.trtllm_dist_fused_linear_all_reduce.default, + torch.ops.auto_deploy.torch_fused_linear_all_reduce.default, torch.ops.auto_deploy.torch_linear_simple.default, torch.ops.aten.split_with_sizes.default, ) diff --git a/tensorrt_llm/_torch/auto_deploy/transformations/transform.py b/tensorrt_llm/_torch/auto_deploy/transformations/transform.py index c2ddb00dd15..633dfc87a59 100644 --- a/tensorrt_llm/_torch/auto_deploy/transformations/transform.py +++ b/tensorrt_llm/_torch/auto_deploy/transformations/transform.py @@ -21,8 +21,8 @@ detect_ep_shard, eliminate_redundant_transposes, fuse_allreduce_residual_rmsnorm, - fuse_collectives, fuse_rmsnorm, + fuse_torch_allreduce, insert_cached_attention, match_attention_layout, match_attention_pattern, @@ -150,7 +150,7 @@ def __call__(self, cm: CachedSequenceInterface) -> nn.Module: fuse_allreduce_residual_rmsnorm(egm) # check if we can fuse collectives - fuse_collectives(egm) + fuse_torch_allreduce(egm) # TODO (lucaslie): add backend selection as part of configurable inference optimizers # check if we can fuse rmsnorm diff --git a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py index 48f06c70e60..c5c6a4932b2 100644 --- a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py +++ b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py @@ -262,15 +262,6 @@ def is_bmm_op(node: Node, include_quantization: bool = False) -> bool: return is_op(node, dist_ops) -def is_dist_op(node: Node) -> bool: - """Check if the node is a distributed op.""" - dist_ops = { - torch.ops.auto_deploy.torch_dist_all_gather, - torch.ops.auto_deploy.torch_dist_all_reduce, - } - return is_op(node, dist_ops) - - def get_all_input_output_nodes(graph: Graph) -> Tuple[List[Node], List[Node]]: input_nodes: List[Node] = graph.find_nodes(op="placeholder") output_nodes: List[Node] = graph.find_nodes(op="output") diff --git a/tests/unittest/_torch/auto_deploy/unit/multigpu/custom_ops/test_dist.py b/tests/unittest/_torch/auto_deploy/unit/multigpu/custom_ops/test_dist.py index d4c8091158a..bf1f8e18a3a 100644 --- a/tests/unittest/_torch/auto_deploy/unit/multigpu/custom_ops/test_dist.py +++ 
b/tests/unittest/_torch/auto_deploy/unit/multigpu/custom_ops/test_dist.py @@ -9,14 +9,14 @@ def _run_all_reduce_test(rank, world_size): x = torch.ones(10, 10).to("cuda") - y = torch.ops.auto_deploy.torch_dist_all_reduce(x) + y = torch.ops.auto_deploy.torch_all_reduce(x) assert torch.equal(x * world_size, y) def _run_all_gather_test(rank, world_size): x = torch.ones(10, 10).to("cuda") - y = torch.ops.auto_deploy.torch_dist_all_gather(x) + y = torch.ops.auto_deploy.torch_all_gather(x) assert torch.sum(y) == world_size * torch.sum(x) assert y.shape == (world_size * x.shape[0], *x.shape[1:]) diff --git a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_allreduce_residual_rmsnorm_fusion.py b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_allreduce_residual_rmsnorm_fusion.py index c81ca0ae1c4..9647b718925 100644 --- a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_allreduce_residual_rmsnorm_fusion.py +++ b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_allreduce_residual_rmsnorm_fusion.py @@ -5,8 +5,7 @@ from _dist_test_utils import get_device_counts from torch.export import export -from tensorrt_llm._torch.auto_deploy.distributed import common as dist -from tensorrt_llm._torch.auto_deploy.distributed.trtllm import is_trtllm_op_available +from tensorrt_llm._torch.auto_deploy import distributed as dist from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm from tensorrt_llm._torch.auto_deploy.transformations.library.collectives import ( fuse_allreduce_residual_rmsnorm, @@ -37,14 +36,14 @@ def __init__(self, hidden_size, dtype): self.norm = RMSNorm(hidden_size, 1e-5, dtype) def forward(self, x, residual): - x = torch.ops.auto_deploy.torch_dist_all_reduce(x) + x = torch.ops.auto_deploy.trtllm_all_reduce(x) y = x + residual normed = self.norm(y) return normed, y def _test_allreduce_fusion(port: int): - if not is_trtllm_op_available(): + if not dist.is_trtllm_dist_available(): pytest.skip("Require trtllm ops to run test_allreduce_fusion.") _, _ = dist.initialize_or_skip(port=port) @@ -73,7 +72,7 @@ def _test_allreduce_fusion(port: int): # Check if fused node in the graph has_fused_node = False for node in gm.graph.nodes: - if is_op(node, torch.ops.dist.fused_allreduce_residual_rmsnorm): + if is_op(node, torch.ops.auto_deploy.trtllm_fused_allreduce_residual_rmsnorm): has_fused_node = True assert has_fused_node, "Fused node not found." 
@@ -98,4 +97,7 @@ def test_allreduce_fusion(device_count): n_workers = device_count mpi_pool = MpiPoolSession(n_workers=n_workers) - mpi_pool.submit_sync(_test_allreduce_fusion, port=port) + try: + mpi_pool.submit_sync(_test_allreduce_fusion, port=port) + finally: + mpi_pool.shutdown() diff --git a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py index ab135aa28a1..a0ffdea0ef2 100644 --- a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py +++ b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py @@ -75,7 +75,7 @@ def transform_func(gm) -> None: sharding_transform_executor(gm, sharding_config) # now run the test - op_expected = getattr(torch.ops.auto_deploy, "torch_dist_all_gather") + op_expected = getattr(torch.ops.auto_deploy, "torch_all_gather") run_test( model, x, diff --git a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_collective_fusion.py b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_collective_fusion.py index 4aa1a875c42..51742537ac3 100644 --- a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_collective_fusion.py +++ b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_collective_fusion.py @@ -13,7 +13,7 @@ import tensorrt_llm._torch.auto_deploy.distributed.common as dist_common from tensorrt_llm._torch.auto_deploy.custom_ops.quant import FP8Linear -from tensorrt_llm._torch.auto_deploy.transformations.library import fuse_collectives +from tensorrt_llm._torch.auto_deploy.transformations.library import fuse_torch_allreduce from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op @@ -26,8 +26,8 @@ def __init__(self, in_features, out_features, bias, cls): self.linear2 = cls(4 * in_features, out_features, bias=bias) def forward(self, x): - y = F.relu(torch.ops.auto_deploy.torch_dist_all_reduce(self.linear1(x))) - return torch.ops.auto_deploy.torch_dist_all_reduce(self.linear2(y)) + y = F.relu(torch.ops.auto_deploy.torch_all_reduce(self.linear1(x))) + return torch.ops.auto_deploy.torch_all_reduce(self.linear2(y)) def _run_job( @@ -58,14 +58,14 @@ def _get_expected_num_params(num_p_og: int) -> int: def check_transformed_graph(gm): return any(is_op(n, op_expected) for n in gm.graph.nodes) and not any( - is_op(n, torch.ops.auto_deploy.torch_dist_all_reduce) for n in gm.graph.nodes + is_op(n, torch.ops.auto_deploy.torch_all_reduce) for n in gm.graph.nodes ) # now run the test run_test( model, x, - transform=fuse_collectives, + transform=fuse_torch_allreduce, check_transformed_graph=check_transformed_graph, _get_expected_num_params=_get_expected_num_params, test_load_hook=False, @@ -76,7 +76,7 @@ def check_transformed_graph(gm): @pytest.mark.parametrize( "linear_cls, dist_op_expected", ( - (nn.Linear, "auto_deploy.trtllm_dist_fused_linear_all_reduce"), + (nn.Linear, "auto_deploy.torch_fused_linear_all_reduce"), pytest.param( FP8Linear, "auto_deploy.torch_quant_fused_fp8_linear_all_reduce", diff --git a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py index 19cce483297..f79d0788caf 100644 --- a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py +++ 
b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py @@ -44,7 +44,7 @@ def transform_func(gm) -> None: detect_ep_shard(gm, rank, world_size, sharding_config) sharding_transform_executor(gm, sharding_config) - op_expected = torch.ops.auto_deploy.torch_dist_all_reduce + op_expected = torch.ops.auto_deploy.torch_all_reduce run_test( model, diff --git a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py index 9e33bef4a91..4d4108ca18c 100644 --- a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py +++ b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py @@ -115,7 +115,7 @@ def _run_job( def _get_expected_num_params(num_p_og: int) -> int: num_update = 0 - if bias and dist_op_expected == "torch_dist_all_reduce": + if bias and dist_op_expected == "torch_all_reduce": num_p_og -= num_features num_update = num_features * (rank == world_size - 1) @@ -275,9 +275,9 @@ def _run_pattern_detection_job( @pytest.mark.parametrize( "model_cls, dist_op_expected", ( - (MLP, "torch_dist_all_reduce"), - (nn.Linear, "torch_dist_all_gather"), - (GQA_Block, "torch_dist_all_reduce"), + (MLP, "torch_all_reduce"), + (nn.Linear, "torch_all_gather"), + (GQA_Block, "torch_all_reduce"), ), ) def test_sharding(model_cls: Type[nn.Module], dist_op_expected: str, bias: bool, device_count: int): @@ -292,9 +292,9 @@ def test_sharding(model_cls: Type[nn.Module], dist_op_expected: str, bias: bool, @pytest.mark.parametrize( "model_cls, dist_op_expected", ( - (MLP, "torch_dist_all_reduce"), - (nn.Linear, "torch_dist_all_gather"), - (GQA_Block, "torch_dist_all_reduce"), + (MLP, "torch_all_reduce"), + (nn.Linear, "torch_all_gather"), + (GQA_Block, "torch_all_reduce"), ), ) def test_sharding_pattern_detection(