Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 4 additions & 42 deletions cpp/tensorrt_llm/common/customAllReduceUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,50 +40,12 @@ inline size_t getMaxRequiredWorkspaceSize(int worldSize) noexcept
{
return common::getEnvAllReduceWorkspaceSize();
}
if (worldSize <= 2)
char const* envWorkspaceSize = std::getenv("TRTLLM_ALLREDUCE_FUSION_WORKSPACE_SIZE");
if (envWorkspaceSize != nullptr)
{
return 16 * 1000 * 1000;
}
return 8 * 1000 * 1000;
}

// (SM major_version, TP_size) -> (NCCL_num_token_threshold, TWO_SHOT_numel_threshold)
inline std::unordered_map<int, std::unordered_map<int, std::pair<size_t, size_t>>> HeuristicThresholdLP{
{90,
{
{2, {4096, 4096 * 4096}},
{4, {4096, 1024 * 1024}},
{8, {2048, 512 * 512}},
}},
{100,
{
{2, {4096, 4096 * 4096}},
{4, {4096, 1024 * 2048}},
{8, {4096, 1024 * 1024}},
}},
};

inline AllReduceStrategyType SelectStrategyLP(size_t seq_len, size_t hidden_size, int world_size, AllReduceFusionOp op)
{
// The heuristic is based on the following assumptions:
// __________________________________
// | \ TWO-SHOT zone |
// | ONE-SHOT zone \ | NCCL zone
// |_______________________\______|___
// sm_major is 90 or 100

auto const sm_major = std::min(100, std::max(90, tensorrt_llm::common::getSMVersion()));

auto const [nccl_num_token_threshold, two_shot_numel_threshold] = HeuristicThresholdLP[sm_major][world_size];
auto const message_size = seq_len * hidden_size;
if (message_size >= two_shot_numel_threshold)
{
return AllReduceStrategyType::TWOSHOT;
}
else
{
return AllReduceStrategyType::ONESHOT;
return static_cast<size_t>(std::atoi(envWorkspaceSize));
}
return 67108864; // 64 MiB
}

// use 1D vector to store the best strategy instead of a map for each sm version
Expand Down
5 changes: 5 additions & 0 deletions tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from tensorrt_llm._torch.attention_backend.interface import AttentionRuntimeFeatures
from tensorrt_llm._torch.auto_deploy.utils._graph import get_input_embeddings, get_lm_head_weights
from tensorrt_llm._torch.autotuner import AutoTuner
from tensorrt_llm._torch.models.modeling_speculative import Eagle3ForCausalLM
from tensorrt_llm._torch.pyexecutor._util import (
_create_kv_cache_manager,
Expand Down Expand Up @@ -1008,6 +1009,10 @@ def create_autodeploy_executor(ad_config: LlmArgs, tokenizer: Optional[Tokenizer
torch.cuda.set_device(rank)
port = mpi_dist.broadcast(dist.get_free_port()) # use MPI broadcast to pick a free port
dist.initialize_or_skip(rank, world_size, port)

# Setup AutoTuner with distributed state for allreduce autotuning
AutoTuner.get().setup_distributed_state(dist_mapping, mpi_dist)

# some config
assert ad_config.max_beam_width <= 1, "_autodeploy + beam_search is not supported"

Expand Down
39 changes: 27 additions & 12 deletions tensorrt_llm/_torch/autotuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import tensorrt_llm
from tensorrt_llm._torch.distributed import Distributed
from tensorrt_llm._utils import nvtx_range
from tensorrt_llm.bindings.internal.runtime import delay_kernel
from tensorrt_llm.logger import logger
from tensorrt_llm.mapping import Mapping
Expand Down Expand Up @@ -844,8 +845,10 @@ def choose_one(
custom_op, runners, p.get_opt_shapes(), tuning_config)
if not is_cache_hit:
# Initialize runner and tactic as None in case of no valid tactic or runners are found
best_runner_id, best_tactic, min_time, has_tuning_failure_occurred = self._profile_runners(
custom_op, runners, tensors, p, tuning_config, **kwargs)
with nvtx_range(f"{custom_op}, shape {p.get_opt_shapes()}"):
best_runner_id, best_tactic, min_time, has_tuning_failure_occurred = self._profile_runners(
custom_op, runners, tensors, p, tuning_config,
**kwargs)
new_tuning_failure_occurred = new_tuning_failure_occurred or has_tuning_failure_occurred

self._maybe_sync_cache_data(tuning_config.distributed_tuning_strategy,
Expand Down Expand Up @@ -882,6 +885,13 @@ def _profile_runners(
tuning_config: TuningConfig,
**kwargs,
) -> float:
"""Profile runners and select the best tactic.

For multi-rank profiling, only rank 0 performs the actual profiling
to avoid sync issues when different ranks select different tactics.
The results are then broadcasted to all other ranks.
"""

min_time = float('inf')
has_tuning_failure_occurred = False
best_runner_id, best_tactic = None, None
Expand Down Expand Up @@ -909,14 +919,15 @@ def _profile_runners(

for tac in valid_tactics:
try:
time_measured = self._profile_single_kernel(
runner=runner,
inputs=input_tensors,
tactic=tac,
tuning_config=tuning_config,
use_cuda_graph=tuning_config.use_cuda_graph,
**kwargs,
)
with nvtx_range(f"r{runner_id}, tactic {tac}"):
time_measured = self._profile_single_kernel(
runner=runner,
inputs=input_tensors,
tactic=tac,
tuning_config=tuning_config,
use_cuda_graph=tuning_config.use_cuda_graph,
**kwargs,
)
except Exception as e:
# Handle None tensors for optional inputs
shapes = self._get_input_sizes(input_tensors)
Expand Down Expand Up @@ -1026,10 +1037,13 @@ def pure_profile(stream: torch.cuda.Stream, repeat: int):
)

stream.synchronize()
if tuning_config.distributed_tuning_strategy == DistributedTuningStrategy.MERGE:
# Currently only AllReduce will use this strategy, and only MPI parallel will enable tuning.
# TODO: Unified tp barrier for both MPIDist and TorchDist.
if hasattr(self._dist, "tp_comm"):
self._dist.tp_comm.barrier()

# Delay the profiled kernel launch to eliminate affects of host time overhead in profiling.
# TODO: This is build time sensitive, O(tactic_num * impl_num * num_profile * tunable_ops)
# Consider apply a preprofiling to estimate the kernel execution time, then decide the necessity.
if use_cuda_graph:
delay_kernel(self._CUDA_GRAPH_DELAY_MICRO_SECS, stream)
else:
Expand All @@ -1052,6 +1066,7 @@ def pure_profile(stream: torch.cuda.Stream, repeat: int):

return start.elapsed_time(end) / repeat

# warm up, no timing
for _ in range(self.warmup):
runner(input_tensor_batches[-1], tactic=tactic, **kwargs)

Expand Down
88 changes: 54 additions & 34 deletions tensorrt_llm/_torch/compilation/patterns/ar_residual_norm.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from operator import getitem
from typing import List, Optional
from typing import Callable, List, Optional

import torch
from torch._inductor.pattern_matcher import (MULTIPLE, CallFunction, Ignored,
Expand All @@ -14,13 +14,13 @@
from tensorrt_llm.mapping import Mapping


def register_ar_residual_norm(custom_pass: PatternMatcherPass,
mapping: Mapping):
def register_ar_residual_norm(custom_pass: PatternMatcherPass, mapping: Mapping,
allreduce_func: Callable):
residual_key = KeywordArg("residual")
trtllm_allreduce_default = CallFunction(
torch.ops.trtllm.allreduce.default, KeywordArg("input"), None, None,
None, None, KeywordArg("workspace"), mapping.tp_group,
KeywordArg("strategy"), int(AllReduceFusionOp.NONE), Ignored(),
allreduce_func.default, KeywordArg("input"), None, None, None, None,
KeywordArg("workspace"), mapping.tp_group, KeywordArg("strategy"),
int(AllReduceFusionOp.NONE), Ignored(),
KeywordArg("trigger_completion_at_end"))
getitem_x = CallFunction(getitem, trtllm_allreduce_default, 0)
add_Tensor = CallFunction(aten.add.Tensor,
Expand Down Expand Up @@ -56,7 +56,7 @@ def target_pattern(
eps: float,
trigger_completion_at_end: bool,
):
all_reduce_output = torch.ops.trtllm.allreduce(
all_reduce_output = allreduce_func(
input, residual, norm_weight, None, None, workspace,
mapping.tp_group, int(strategy),
int(AllReduceFusionOp.RESIDUAL_RMS_NORM), float(eps),
Expand Down Expand Up @@ -111,10 +111,11 @@ def check_non_ub_strategy(match, strategy_node) -> bool:


def register_ar_residual_norm_out_fp8_quant(custom_pass: PatternMatcherPass,
mapping: Mapping):
mapping: Mapping,
allreduce_func: Callable):
input_node = KeywordArg("input")
strategy_node = KeywordArg("strategy")
allreduce_default = CallFunction(torch.ops.trtllm.allreduce.default,
allreduce_default = CallFunction(allreduce_func.default,
input_node,
KeywordArg("residual"),
KeywordArg("gamma"),
Expand Down Expand Up @@ -165,7 +166,7 @@ def target_pattern(
scale: torch.Tensor,
trigger_completion_at_end: bool,
):
allreduce = torch.ops.trtllm.allreduce(
allreduce = allreduce_func(
input, residual, gamma, scale, None, workspace, mapping.tp_group,
int(strategy),
int(AllReduceFusionOp.RESIDUAL_RMS_NORM_OUT_QUANT_FP8), float(eps),
Expand All @@ -188,10 +189,11 @@ def extra_check(match: Match) -> bool:


def register_ar_residual_norm_fp8_quant(custom_pass: PatternMatcherPass,
mapping: Mapping):
mapping: Mapping,
allreduce_func: Callable):
input_node = KeywordArg("input")
strategy_node = KeywordArg("strategy")
allreduce_default = CallFunction(torch.ops.trtllm.allreduce.default,
allreduce_default = CallFunction(allreduce_func.default,
input_node,
KeywordArg("residual"),
KeywordArg("gamma"),
Expand Down Expand Up @@ -242,7 +244,7 @@ def target_pattern(
scale: torch.Tensor,
trigger_completion_at_end: bool,
):
allreduce = torch.ops.trtllm.allreduce(
allreduce = allreduce_func(
input, residual, gamma, scale, None, workspace, mapping.tp_group,
int(strategy), int(AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_FP8),
float(eps), trigger_completion_at_end)
Expand All @@ -264,10 +266,11 @@ def extra_check(match: Match) -> bool:


def register_ar_residual_norm_out_fp4_quant(custom_pass: PatternMatcherPass,
mapping: Mapping):
mapping: Mapping,
allreduce_func: Callable):
input_node = KeywordArg("input")
strategy_node = KeywordArg("strategy")
allreduce_default = CallFunction(torch.ops.trtllm.allreduce.default,
allreduce_default = CallFunction(allreduce_func.default,
input_node,
KeywordArg("residual"),
KeywordArg("gamma"),
Expand Down Expand Up @@ -313,7 +316,7 @@ def target_pattern(
scale: torch.Tensor,
trigger_completion_at_end: bool,
):
allreduce = torch.ops.trtllm.allreduce(
allreduce = allreduce_func(
input, residual, gamma, scale, None, workspace, mapping.tp_group,
int(strategy),
int(AllReduceFusionOp.RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4),
Expand All @@ -336,10 +339,11 @@ def extra_check(match: Match) -> bool:


def register_ar_residual_norm_fp4_quant(custom_pass: PatternMatcherPass,
mapping: Mapping):
mapping: Mapping,
allreduce_func: Callable):
input_node = KeywordArg("input")
strategy_node = KeywordArg("strategy")
allreduce_default = CallFunction(torch.ops.trtllm.allreduce.default,
allreduce_default = CallFunction(allreduce_func.default,
input_node,
KeywordArg("residual"),
KeywordArg("gamma"),
Expand Down Expand Up @@ -385,7 +389,7 @@ def target_pattern(
scale: torch.Tensor,
trigger_completion_at_end: bool,
):
allreduce = torch.ops.trtllm.allreduce(
allreduce = allreduce_func(
input, residual, gamma, scale, None, workspace, mapping.tp_group,
int(strategy), int(AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4),
float(eps), trigger_completion_at_end)
Expand All @@ -407,17 +411,20 @@ def extra_check(match: Match) -> bool:


def register_ub_patterns(custom_passes: List[PatternMatcherPass],
mapping: Mapping):
mapping: Mapping, allreduce_func: Callable):

def register_convert_supported_ar_to_ub(custom_pass: PatternMatcherPass):
strategy = int(AllReduceStrategy.AUTO)
input_node = KeywordArg('input')
fusion = KeywordArg('fusion_op')
trtllm_allreduce_default = CallFunction(
torch.ops.trtllm.allreduce.default, input_node,
KeywordArg('residual_in'), KeywordArg('gamma'), KeywordArg('scale'),
None, Ignored(), mapping.tp_group, strategy, fusion,
KeywordArg('eps'), Ignored())
trtllm_allreduce_default = CallFunction(allreduce_func.default,
input_node,
KeywordArg('residual_in'),
KeywordArg('gamma'),
KeywordArg('scale'), None,
Ignored(), mapping.tp_group,
strategy, fusion,
KeywordArg('eps'), Ignored())

def empty_convert_supported_ar_to_ub(
input: torch.Tensor,
Expand Down Expand Up @@ -667,7 +674,7 @@ def register_ub_finalize_patterns(custom_pass: PatternMatcherPass):
torch.ops.trtllm.userbuffers_allreduce_finalize.default,
KeywordArg("sharded_residual"), False)
trtllm_allreduce_default = CallFunction(
torch.ops.trtllm.allreduce.default, KeywordArg("input"),
torch.ops.trtllm.allreduce, KeywordArg("input"),
trtllm_userbuffers_allreduce_finalize_default, KeywordArg("gamma"),
KeywordArg("scale"), Ignored(), Ignored(), mapping.tp_group,
int(AllReduceStrategy.UB), KeywordArg("fusion_op"),
Expand Down Expand Up @@ -718,15 +725,28 @@ def target_finalize_pattern(

def register_ar_fusions(custom_passes: List[PatternMatcherPass],
mapping: Mapping, enable_ub: bool):
register_ar_residual_norm(custom_passes[-1], mapping)
register_ar_residual_norm(custom_passes[-1], mapping,
torch.ops.trtllm.allreduce)
register_ar_residual_norm(custom_passes[-1], mapping,
torch.ops.trtllm.tunable_allreduce)

custom_passes.append(PatternMatcherPass())
register_ar_residual_norm_fp8_quant(custom_passes[-1], mapping)
register_ar_residual_norm_fp4_quant(custom_passes[-1], mapping)
# AR-Residual-Norm-Out-Quant-X is not supported by Userbuffers kernel.
if not enable_ub:
register_ar_residual_norm_out_fp8_quant(custom_passes[-1], mapping)
register_ar_residual_norm_out_fp4_quant(custom_passes[-1], mapping)
for allreduce_func in [
torch.ops.trtllm.allreduce, torch.ops.trtllm.tunable_allreduce
]:
register_ar_residual_norm_fp8_quant(custom_passes[-1], mapping,
allreduce_func)
register_ar_residual_norm_fp4_quant(custom_passes[-1], mapping,
allreduce_func)

# AR-Residual-Norm-Out-Quant-X is not supported by Userbuffers kernel.
if not enable_ub:
register_ar_residual_norm_out_fp8_quant(custom_passes[-1], mapping,
allreduce_func)
register_ar_residual_norm_out_fp4_quant(custom_passes[-1], mapping,
allreduce_func)

if enable_ub:
register_ub_patterns(custom_passes, mapping)
register_ub_patterns(custom_passes, mapping, torch.ops.trtllm.allreduce)
register_ub_patterns(custom_passes, mapping,
torch.ops.trtllm.tunable_allreduce)
Loading