"""TRT-LLM distributed operations and fused kernels.

This module defines atomic TRT-LLM-specific ops that use optimized kernels.
The torch fallback variants are defined separately to enable multi-pattern matching.
"""

from typing import List, Optional

import torch

# Use TRT-LLM distributed ops to improve TP performance where possible.
from ....mapping import Mapping
from ...distributed import AllReduce, allgather
from ...modules.linear import AllReduceFusionOp, AllReduceParams, AllReduceStrategy
from ..distributed.common import ReduceOp, get_rank_world_size, get_world_size, is_ompi

# Cache AllReduce modules to avoid recreating them on every call.
# This is critical for CUDA graph compatibility: recreating modules during
# warmup causes hangs due to workspace allocation with CPU synchronization.
_allreduce_cache = {}


def trtllm_allgather(tensor, dim, sizes=None):
    rank, world_size = get_rank_world_size()
    p_config = Mapping(world_size=world_size, tp_size=world_size, rank=rank)
    return allgather(tensor, p_config, dim=dim, sizes=sizes)
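

# Note: trtllm_allgather (and trtllm_allreduce below) builds a pure tensor-parallel
# Mapping on the fly (tp_size == world_size), i.e. every rank in the job belongs to
# the same TP group.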


def trtllm_allreduce(tensor, op, strategy: str, all_reduce_params=None):
    rank, world_size = get_rank_world_size()
    assert op == ReduceOp.SUM, "TRT-LLM allreduce only supports the SUM op."

    # Convert the string strategy name to an AllReduceStrategy enum member.
    try:
        strategy_enum = getattr(AllReduceStrategy, strategy)
    except AttributeError:
        raise ValueError(
            f"Invalid allreduce strategy: {strategy}. "
            "Valid options: AUTO, NCCL, ONESHOT, TWOSHOT, MIN_LATENCY, "
            "LOWPRECISION, UB, MNNVL, NCCL_SYMMETRIC"
        )

    # The cache key includes rank, world_size, dtype, and strategy so that different
    # configurations get their own AllReduce module.
    cache_key = (rank, world_size, tensor.dtype, strategy_enum)
    if cache_key not in _allreduce_cache:
        p_config = Mapping(world_size=world_size, tp_size=world_size, rank=rank)
        _allreduce_cache[cache_key] = AllReduce(
            mapping=p_config, strategy=strategy_enum, dtype=tensor.dtype
        )

    torch_op = _allreduce_cache[cache_key]
    return torch_op(tensor, all_reduce_params=all_reduce_params)
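

# Usage sketch (illustration only; assumes a multi-rank MPI launch so that
# get_rank_world_size() reports this process's rank and the job's world size):
#
#   reduced = trtllm_allreduce(local_tensor, op=ReduceOp.SUM, strategy="AUTO")
#
# The first call for a given (rank, world_size, dtype, strategy) combination builds
# the AllReduce module and caches it; subsequent calls, including those captured in
# CUDA graphs, reuse the cached module.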


# ============================================================================
# TRT-LLM Backend Ops (MPI mode)
# ============================================================================


@torch.library.custom_op(
    "auto_deploy::trtllm_dist_all_gather", mutates_args=(), device_types="cuda"
)
def trtllm_dist_all_gather(
    tensor: torch.Tensor, dim: int = 0, sizes: Optional[List[int]] = None
) -> torch.Tensor:
    """All-gather using the TRT-LLM optimized backend.

    This op always uses TRT-LLM's optimized allgather and is used in MPI mode.
    """
    return trtllm_allgather(tensor, dim=dim, sizes=sizes)


@trtllm_dist_all_gather.register_fake
def trtllm_dist_all_gather_fake(tensor, dim=0, sizes=None):
    return torch.cat([torch.empty_like(tensor) for _ in range(get_world_size())], dim=dim)
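
# Note: register_fake supplies a shape/dtype-only "fake" implementation so the op can
# be traced with FakeTensors (e.g. during graph capture/export) without launching the
# actual collective; the fake all-gather output simply concatenates world_size copies
# of the input along `dim`.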


@torch.library.custom_op(
    "auto_deploy::trtllm_dist_all_reduce", mutates_args=(), device_types="cuda"
)
def trtllm_dist_all_reduce(t: torch.Tensor, strategy: str) -> torch.Tensor:
    """All-reduce using the TRT-LLM optimized backend. The reduction op is SUM.

    This op always uses TRT-LLM's optimized allreduce and is used in MPI mode.
    """
    return trtllm_allreduce(t, op=ReduceOp.SUM, strategy=strategy)


@trtllm_dist_all_reduce.register_fake
def trtllm_dist_all_reduce_fake(tensor, strategy):
    return torch.empty_like(tensor)


# TRT-LLM fused op (atomic: always uses the TRT-LLM backend).
@torch.library.custom_op(
    "dist::trtllm_fused_allreduce_residual_rmsnorm", mutates_args=(), device_types="cuda"
)
def trtllm_fused_allreduce_residual_rmsnorm(
    tensor: torch.Tensor,
    residual: torch.Tensor,
    norm_weight: torch.Tensor,
    eps: float,
    strategy: str,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Fused allreduce + residual add + RMSNorm using TRT-LLM's optimized kernel.

    This op always uses TRT-LLM's fused kernel and is used in MPI mode.
    """
    all_reduce_params = AllReduceParams(
        fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
        bias=None,
        residual=residual,
        norm_weight=norm_weight,
        eps=eps,
    )
    return trtllm_allreduce(
        tensor, ReduceOp.SUM, strategy=strategy, all_reduce_params=all_reduce_params
    )


@trtllm_fused_allreduce_residual_rmsnorm.register_fake
def trtllm_fused_allreduce_residual_rmsnorm_fake(
    tensor: torch.Tensor,
    residual: torch.Tensor,
    norm_weight: torch.Tensor,
    eps: float,
    strategy: str,
) -> tuple[torch.Tensor, torch.Tensor]:
    return torch.empty_like(tensor), torch.empty_like(tensor)
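

# Reference semantics (a sketch for illustration only; not used by the ops above, and
# the exact output order/precision handling of the TRT-LLM fused kernel may differ):
# the fused op is assumed to be equivalent to the unfused sequence
# "allreduce -> add residual -> RMSNorm", returning the normalized output together
# with the pre-norm hidden state.
def _unfused_allreduce_residual_rmsnorm_sketch(
    tensor: torch.Tensor,
    residual: torch.Tensor,
    norm_weight: torch.Tensor,
    eps: float,
    strategy: str = "AUTO",
) -> tuple[torch.Tensor, torch.Tensor]:
    # Plain SUM allreduce followed by the residual add.
    hidden = trtllm_allreduce(tensor, op=ReduceOp.SUM, strategy=strategy) + residual
    # RMSNorm in fp32 for numerical stability, then scale by the learned weight.
    hidden_f32 = hidden.to(torch.float32)
    normed = hidden_f32 * torch.rsqrt(hidden_f32.pow(2).mean(-1, keepdim=True) + eps)
    return normed.to(hidden.dtype) * norm_weight, hidden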


def is_trtllm_op_available():
    """Check if TRT-LLM ops are available and running with MPI."""
    return is_ompi()
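

# Dispatch sketch (illustration only; `torch_fallback_all_reduce` is a hypothetical
# stand-in for the torch fallback variant mentioned in the module docstring):
#
#   if is_trtllm_op_available():
#       out = torch.ops.auto_deploy.trtllm_dist_all_reduce(t, "AUTO")
#   else:
#       out = torch_fallback_all_reduce(t)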