Skip to content

Commit d272f1a

Browse files
authored
[TRTLLM-8821][feat] Apply AutoTuner to AllReduce Op for strategy tuning. (#8531)
Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
1 parent 2f768b7 commit d272f1a

File tree

13 files changed

+409
-179
lines changed

13 files changed

+409
-179
lines changed

cpp/tensorrt_llm/common/customAllReduceUtils.h

Lines changed: 4 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -40,50 +40,12 @@ inline size_t getMaxRequiredWorkspaceSize(int worldSize) noexcept
4040
{
4141
return common::getEnvAllReduceWorkspaceSize();
4242
}
43-
if (worldSize <= 2)
43+
char const* envWorkspaceSize = std::getenv("TRTLLM_ALLREDUCE_FUSION_WORKSPACE_SIZE");
44+
if (envWorkspaceSize != nullptr)
4445
{
45-
return 16 * 1000 * 1000;
46-
}
47-
return 8 * 1000 * 1000;
48-
}
49-
50-
// (SM major_version, TP_size) -> (NCCL_num_token_threshold, TWO_SHOT_numel_threshold)
51-
inline std::unordered_map<int, std::unordered_map<int, std::pair<size_t, size_t>>> HeuristicThresholdLP{
52-
{90,
53-
{
54-
{2, {4096, 4096 * 4096}},
55-
{4, {4096, 1024 * 1024}},
56-
{8, {2048, 512 * 512}},
57-
}},
58-
{100,
59-
{
60-
{2, {4096, 4096 * 4096}},
61-
{4, {4096, 1024 * 2048}},
62-
{8, {4096, 1024 * 1024}},
63-
}},
64-
};
65-
66-
inline AllReduceStrategyType SelectStrategyLP(size_t seq_len, size_t hidden_size, int world_size, AllReduceFusionOp op)
67-
{
68-
// The heuristic is based on the following assumptions:
69-
// __________________________________
70-
// | \ TWO-SHOT zone |
71-
// | ONE-SHOT zone \ | NCCL zone
72-
// |_______________________\______|___
73-
// sm_major is 90 or 100
74-
75-
auto const sm_major = std::min(100, std::max(90, tensorrt_llm::common::getSMVersion()));
76-
77-
auto const [nccl_num_token_threshold, two_shot_numel_threshold] = HeuristicThresholdLP[sm_major][world_size];
78-
auto const message_size = seq_len * hidden_size;
79-
if (message_size >= two_shot_numel_threshold)
80-
{
81-
return AllReduceStrategyType::TWOSHOT;
82-
}
83-
else
84-
{
85-
return AllReduceStrategyType::ONESHOT;
46+
return static_cast<size_t>(std::atoi(envWorkspaceSize));
8647
}
48+
return 67108864; // 64 MiB
8749
}
8850

8951
// use 1D vector to store the best strategy instead of a map for each sm version

tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
from tensorrt_llm._torch.attention_backend.interface import AttentionRuntimeFeatures
2525
from tensorrt_llm._torch.auto_deploy.utils._graph import get_input_embeddings, get_lm_head_weights
26+
from tensorrt_llm._torch.autotuner import AutoTuner
2627
from tensorrt_llm._torch.models.modeling_speculative import Eagle3ForCausalLM
2728
from tensorrt_llm._torch.pyexecutor._util import (
2829
_create_kv_cache_manager,
@@ -1008,6 +1009,10 @@ def create_autodeploy_executor(ad_config: LlmArgs, tokenizer: Optional[Tokenizer
10081009
torch.cuda.set_device(rank)
10091010
port = mpi_dist.broadcast(dist.get_free_port()) # use MPI broadcast to pick a free port
10101011
dist.initialize_or_skip(rank, world_size, port)
1012+
1013+
# Setup AutoTuner with distributed state for allreduce autotuning
1014+
AutoTuner.get().setup_distributed_state(dist_mapping, mpi_dist)
1015+
10111016
# some config
10121017
assert ad_config.max_beam_width <= 1, "_autodeploy + beam_search is not supported"
10131018

tensorrt_llm/_torch/autotuner.py

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import tensorrt_llm
2121
from tensorrt_llm._torch.distributed import Distributed
22+
from tensorrt_llm._utils import nvtx_range
2223
from tensorrt_llm.bindings.internal.runtime import delay_kernel
2324
from tensorrt_llm.logger import logger
2425
from tensorrt_llm.mapping import Mapping
@@ -856,8 +857,10 @@ def choose_one(
856857
custom_op, runners, p.get_opt_shapes(), tuning_config)
857858
if not is_cache_hit:
858859
# Initialize runner and tactic as None in case of no valid tactic or runners are found
859-
best_runner_id, best_tactic, min_time, has_tuning_failure_occurred = self._profile_runners(
860-
custom_op, runners, tensors, p, tuning_config, **kwargs)
860+
with nvtx_range(f"{custom_op}, shape {p.get_opt_shapes()}"):
861+
best_runner_id, best_tactic, min_time, has_tuning_failure_occurred = self._profile_runners(
862+
custom_op, runners, tensors, p, tuning_config,
863+
**kwargs)
861864
new_tuning_failure_occurred = new_tuning_failure_occurred or has_tuning_failure_occurred
862865

863866
self._maybe_sync_cache_data(tuning_config.distributed_tuning_strategy,
@@ -894,6 +897,13 @@ def _profile_runners(
894897
tuning_config: TuningConfig,
895898
**kwargs,
896899
) -> float:
900+
"""Profile runners and select the best tactic.
901+
902+
For multi-rank profiling, only rank 0 performs the actual profiling
903+
to avoid sync issues when different ranks select different tactics.
904+
The results are then broadcasted to all other ranks.
905+
"""
906+
897907
min_time = float('inf')
898908
has_tuning_failure_occurred = False
899909
best_runner_id, best_tactic = None, None
@@ -921,14 +931,15 @@ def _profile_runners(
921931

922932
for tac in valid_tactics:
923933
try:
924-
time_measured = self._profile_single_kernel(
925-
runner=runner,
926-
inputs=input_tensors,
927-
tactic=tac,
928-
tuning_config=tuning_config,
929-
use_cuda_graph=tuning_config.use_cuda_graph,
930-
**kwargs,
931-
)
934+
with nvtx_range(f"r{runner_id}, tactic {tac}"):
935+
time_measured = self._profile_single_kernel(
936+
runner=runner,
937+
inputs=input_tensors,
938+
tactic=tac,
939+
tuning_config=tuning_config,
940+
use_cuda_graph=tuning_config.use_cuda_graph,
941+
**kwargs,
942+
)
932943
except Exception as e:
933944
# Handle None tensors for optional inputs
934945
shapes = self._get_input_sizes(input_tensors)
@@ -1038,10 +1049,13 @@ def pure_profile(stream: torch.cuda.Stream, repeat: int):
10381049
)
10391050

10401051
stream.synchronize()
1052+
if tuning_config.distributed_tuning_strategy == DistributedTuningStrategy.MERGE:
1053+
# Currently only AllReduce will use this strategy, and only MPI parallel will enable tuning.
1054+
# TODO: Unified tp barrier for both MPIDist and TorchDist.
1055+
if hasattr(self._dist, "tp_comm"):
1056+
self._dist.tp_comm.barrier()
10411057

10421058
# Delay the profiled kernel launch to eliminate affects of host time overhead in profiling.
1043-
# TODO: This is build time sensitive, O(tactic_num * impl_num * num_profile * tunable_ops)
1044-
# Consider apply a preprofiling to estimate the kernel execution time, then decide the necessity.
10451059
if use_cuda_graph:
10461060
delay_kernel(self._CUDA_GRAPH_DELAY_MICRO_SECS, stream)
10471061
else:
@@ -1064,6 +1078,7 @@ def pure_profile(stream: torch.cuda.Stream, repeat: int):
10641078

10651079
return start.elapsed_time(end) / repeat
10661080

1081+
# warm up, no timing
10671082
for _ in range(self.warmup):
10681083
runner(input_tensor_batches[-1], tactic=tactic, **kwargs)
10691084

tensorrt_llm/_torch/compilation/patterns/ar_residual_norm.py

Lines changed: 54 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from operator import getitem
2-
from typing import List, Optional
2+
from typing import Callable, List, Optional
33

44
import torch
55
from torch._inductor.pattern_matcher import (MULTIPLE, CallFunction, Ignored,
@@ -14,13 +14,13 @@
1414
from tensorrt_llm.mapping import Mapping
1515

1616

17-
def register_ar_residual_norm(custom_pass: PatternMatcherPass,
18-
mapping: Mapping):
17+
def register_ar_residual_norm(custom_pass: PatternMatcherPass, mapping: Mapping,
18+
allreduce_func: Callable):
1919
residual_key = KeywordArg("residual")
2020
trtllm_allreduce_default = CallFunction(
21-
torch.ops.trtllm.allreduce.default, KeywordArg("input"), None, None,
22-
None, None, KeywordArg("workspace"), mapping.tp_group,
23-
KeywordArg("strategy"), int(AllReduceFusionOp.NONE), Ignored(),
21+
allreduce_func.default, KeywordArg("input"), None, None, None, None,
22+
KeywordArg("workspace"), mapping.tp_group, KeywordArg("strategy"),
23+
int(AllReduceFusionOp.NONE), Ignored(),
2424
KeywordArg("trigger_completion_at_end"))
2525
getitem_x = CallFunction(getitem, trtllm_allreduce_default, 0)
2626
add_Tensor = CallFunction(aten.add.Tensor,
@@ -56,7 +56,7 @@ def target_pattern(
5656
eps: float,
5757
trigger_completion_at_end: bool,
5858
):
59-
all_reduce_output = torch.ops.trtllm.allreduce(
59+
all_reduce_output = allreduce_func(
6060
input, residual, norm_weight, None, None, workspace,
6161
mapping.tp_group, int(strategy),
6262
int(AllReduceFusionOp.RESIDUAL_RMS_NORM), float(eps),
@@ -111,10 +111,11 @@ def check_non_ub_strategy(match, strategy_node) -> bool:
111111

112112

113113
def register_ar_residual_norm_out_fp8_quant(custom_pass: PatternMatcherPass,
114-
mapping: Mapping):
114+
mapping: Mapping,
115+
allreduce_func: Callable):
115116
input_node = KeywordArg("input")
116117
strategy_node = KeywordArg("strategy")
117-
allreduce_default = CallFunction(torch.ops.trtllm.allreduce.default,
118+
allreduce_default = CallFunction(allreduce_func.default,
118119
input_node,
119120
KeywordArg("residual"),
120121
KeywordArg("gamma"),
@@ -165,7 +166,7 @@ def target_pattern(
165166
scale: torch.Tensor,
166167
trigger_completion_at_end: bool,
167168
):
168-
allreduce = torch.ops.trtllm.allreduce(
169+
allreduce = allreduce_func(
169170
input, residual, gamma, scale, None, workspace, mapping.tp_group,
170171
int(strategy),
171172
int(AllReduceFusionOp.RESIDUAL_RMS_NORM_OUT_QUANT_FP8), float(eps),
@@ -188,10 +189,11 @@ def extra_check(match: Match) -> bool:
188189

189190

190191
def register_ar_residual_norm_fp8_quant(custom_pass: PatternMatcherPass,
191-
mapping: Mapping):
192+
mapping: Mapping,
193+
allreduce_func: Callable):
192194
input_node = KeywordArg("input")
193195
strategy_node = KeywordArg("strategy")
194-
allreduce_default = CallFunction(torch.ops.trtllm.allreduce.default,
196+
allreduce_default = CallFunction(allreduce_func.default,
195197
input_node,
196198
KeywordArg("residual"),
197199
KeywordArg("gamma"),
@@ -242,7 +244,7 @@ def target_pattern(
242244
scale: torch.Tensor,
243245
trigger_completion_at_end: bool,
244246
):
245-
allreduce = torch.ops.trtllm.allreduce(
247+
allreduce = allreduce_func(
246248
input, residual, gamma, scale, None, workspace, mapping.tp_group,
247249
int(strategy), int(AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_FP8),
248250
float(eps), trigger_completion_at_end)
@@ -264,10 +266,11 @@ def extra_check(match: Match) -> bool:
264266

265267

266268
def register_ar_residual_norm_out_fp4_quant(custom_pass: PatternMatcherPass,
267-
mapping: Mapping):
269+
mapping: Mapping,
270+
allreduce_func: Callable):
268271
input_node = KeywordArg("input")
269272
strategy_node = KeywordArg("strategy")
270-
allreduce_default = CallFunction(torch.ops.trtllm.allreduce.default,
273+
allreduce_default = CallFunction(allreduce_func.default,
271274
input_node,
272275
KeywordArg("residual"),
273276
KeywordArg("gamma"),
@@ -313,7 +316,7 @@ def target_pattern(
313316
scale: torch.Tensor,
314317
trigger_completion_at_end: bool,
315318
):
316-
allreduce = torch.ops.trtllm.allreduce(
319+
allreduce = allreduce_func(
317320
input, residual, gamma, scale, None, workspace, mapping.tp_group,
318321
int(strategy),
319322
int(AllReduceFusionOp.RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4),
@@ -336,10 +339,11 @@ def extra_check(match: Match) -> bool:
336339

337340

338341
def register_ar_residual_norm_fp4_quant(custom_pass: PatternMatcherPass,
339-
mapping: Mapping):
342+
mapping: Mapping,
343+
allreduce_func: Callable):
340344
input_node = KeywordArg("input")
341345
strategy_node = KeywordArg("strategy")
342-
allreduce_default = CallFunction(torch.ops.trtllm.allreduce.default,
346+
allreduce_default = CallFunction(allreduce_func.default,
343347
input_node,
344348
KeywordArg("residual"),
345349
KeywordArg("gamma"),
@@ -385,7 +389,7 @@ def target_pattern(
385389
scale: torch.Tensor,
386390
trigger_completion_at_end: bool,
387391
):
388-
allreduce = torch.ops.trtllm.allreduce(
392+
allreduce = allreduce_func(
389393
input, residual, gamma, scale, None, workspace, mapping.tp_group,
390394
int(strategy), int(AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4),
391395
float(eps), trigger_completion_at_end)
@@ -407,17 +411,20 @@ def extra_check(match: Match) -> bool:
407411

408412

409413
def register_ub_patterns(custom_passes: List[PatternMatcherPass],
410-
mapping: Mapping):
414+
mapping: Mapping, allreduce_func: Callable):
411415

412416
def register_convert_supported_ar_to_ub(custom_pass: PatternMatcherPass):
413417
strategy = int(AllReduceStrategy.AUTO)
414418
input_node = KeywordArg('input')
415419
fusion = KeywordArg('fusion_op')
416-
trtllm_allreduce_default = CallFunction(
417-
torch.ops.trtllm.allreduce.default, input_node,
418-
KeywordArg('residual_in'), KeywordArg('gamma'), KeywordArg('scale'),
419-
None, Ignored(), mapping.tp_group, strategy, fusion,
420-
KeywordArg('eps'), Ignored())
420+
trtllm_allreduce_default = CallFunction(allreduce_func.default,
421+
input_node,
422+
KeywordArg('residual_in'),
423+
KeywordArg('gamma'),
424+
KeywordArg('scale'), None,
425+
Ignored(), mapping.tp_group,
426+
strategy, fusion,
427+
KeywordArg('eps'), Ignored())
421428

422429
def empty_convert_supported_ar_to_ub(
423430
input: torch.Tensor,
@@ -667,7 +674,7 @@ def register_ub_finalize_patterns(custom_pass: PatternMatcherPass):
667674
torch.ops.trtllm.userbuffers_allreduce_finalize.default,
668675
KeywordArg("sharded_residual"), False)
669676
trtllm_allreduce_default = CallFunction(
670-
torch.ops.trtllm.allreduce.default, KeywordArg("input"),
677+
torch.ops.trtllm.allreduce, KeywordArg("input"),
671678
trtllm_userbuffers_allreduce_finalize_default, KeywordArg("gamma"),
672679
KeywordArg("scale"), Ignored(), Ignored(), mapping.tp_group,
673680
int(AllReduceStrategy.UB), KeywordArg("fusion_op"),
@@ -718,15 +725,28 @@ def target_finalize_pattern(
718725

719726
def register_ar_fusions(custom_passes: List[PatternMatcherPass],
720727
mapping: Mapping, enable_ub: bool):
721-
register_ar_residual_norm(custom_passes[-1], mapping)
728+
register_ar_residual_norm(custom_passes[-1], mapping,
729+
torch.ops.trtllm.allreduce)
730+
register_ar_residual_norm(custom_passes[-1], mapping,
731+
torch.ops.trtllm.tunable_allreduce)
722732

723733
custom_passes.append(PatternMatcherPass())
724-
register_ar_residual_norm_fp8_quant(custom_passes[-1], mapping)
725-
register_ar_residual_norm_fp4_quant(custom_passes[-1], mapping)
726-
# AR-Residual-Norm-Out-Quant-X is not supported by Userbuffers kernel.
727-
if not enable_ub:
728-
register_ar_residual_norm_out_fp8_quant(custom_passes[-1], mapping)
729-
register_ar_residual_norm_out_fp4_quant(custom_passes[-1], mapping)
734+
for allreduce_func in [
735+
torch.ops.trtllm.allreduce, torch.ops.trtllm.tunable_allreduce
736+
]:
737+
register_ar_residual_norm_fp8_quant(custom_passes[-1], mapping,
738+
allreduce_func)
739+
register_ar_residual_norm_fp4_quant(custom_passes[-1], mapping,
740+
allreduce_func)
741+
742+
# AR-Residual-Norm-Out-Quant-X is not supported by Userbuffers kernel.
743+
if not enable_ub:
744+
register_ar_residual_norm_out_fp8_quant(custom_passes[-1], mapping,
745+
allreduce_func)
746+
register_ar_residual_norm_out_fp4_quant(custom_passes[-1], mapping,
747+
allreduce_func)
730748

731749
if enable_ub:
732-
register_ub_patterns(custom_passes, mapping)
750+
register_ub_patterns(custom_passes, mapping, torch.ops.trtllm.allreduce)
751+
register_ub_patterns(custom_passes, mapping,
752+
torch.ops.trtllm.tunable_allreduce)

0 commit comments

Comments
 (0)