
Commit 47d1cf5

[TRTLLM-8129][feat] Apply AutoTuner to AllReduce Op for strategy tuning.

Enable autotune by default if using AUTO strategy.

Signed-off-by: Yukun He <[email protected]>
1 parent 83e02ee commit 47d1cf5
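
In short, the commit wraps the allreduce kernel in a TunableRunner so the AutoTuner can choose among NCCL, ONESHOT, and TWOSHOT per token-count bucket. A condensed sketch of that flow, taken from the torch_custom_ops.py diff below (the tensors, group, and fusion op are assumed to be prepared by the caller):

```python
# Condensed from tunable_allreduce() in the diff below; `group`, `op`, `eps`,
# and the input/workspace tensors are assumed to be set up by the caller.
from tensorrt_llm._torch.autotuner import AutoTuner

tuner = AutoTuner.get()
runner = AllReduceRunner(tp_size, group, op, eps, trigger_completion_at_end)

# choose_one profiles the runner's valid tactics (NCCL / ONESHOT / TWOSHOT)
# for the current shape bucket and caches the winner.
_, best_tactic = tuner.choose_one(
    "trtllm::tunable_allreduce::allreduce",
    [runner],
    AllReduceRunner.tuning_config,
    [input, residual, norm_weight, scale, bias, workspace],
)
outputs = runner([input, residual, norm_weight, scale, bias, workspace],
                 tactic=best_tactic)
```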


10 files changed (+336, -141 lines)


cpp/tensorrt_llm/common/customAllReduceUtils.h

Lines changed: 4 additions & 42 deletions
@@ -40,50 +40,12 @@ inline size_t getMaxRequiredWorkspaceSize(int worldSize) noexcept
     {
         return common::getEnvAllReduceWorkspaceSize();
     }
-    if (worldSize <= 2)
+    char const* envWorkspaceSize = std::getenv("TRTLLM_ALLREDUCE_FUSION_WORKSPACE_SIZE");
+    if (envWorkspaceSize != nullptr)
     {
-        return 16 * 1000 * 1000;
-    }
-    return 8 * 1000 * 1000;
-}
-
-// (SM major_version, TP_size) -> (NCCL_num_token_threshold, TWO_SHOT_numel_threshold)
-inline std::unordered_map<int, std::unordered_map<int, std::pair<size_t, size_t>>> HeuristicThresholdLP{
-    {90,
-        {
-            {2, {4096, 4096 * 4096}},
-            {4, {4096, 1024 * 1024}},
-            {8, {2048, 512 * 512}},
-        }},
-    {100,
-        {
-            {2, {4096, 4096 * 4096}},
-            {4, {4096, 1024 * 2048}},
-            {8, {4096, 1024 * 1024}},
-        }},
-};
-
-inline AllReduceStrategyType SelectStrategyLP(size_t seq_len, size_t hidden_size, int world_size, AllReduceFusionOp op)
-{
-    // The heuristic is based on the following assumptions:
-    // __________________________________
-    // |               \  TWO-SHOT zone |
-    // | ONE-SHOT zone  \               | NCCL zone
-    // |_________________\______________|___
-    //      sm_major is 90 or 100
-
-    auto const sm_major = std::min(100, std::max(90, tensorrt_llm::common::getSMVersion()));
-
-    auto const [nccl_num_token_threshold, two_shot_numel_threshold] = HeuristicThresholdLP[sm_major][world_size];
-    auto const message_size = seq_len * hidden_size;
-    if (message_size >= two_shot_numel_threshold)
-    {
-        return AllReduceStrategyType::TWOSHOT;
-    }
-    else
-    {
-        return AllReduceStrategyType::ONESHOT;
+        return static_cast<size_t>(std::atoi(envWorkspaceSize));
     }
+    return 67108864; // 64 MiB
 }
 
 // use 1D vector to store the best strategy instead of a map for each sm version
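
The removed heuristic table is replaced by a fixed 64 MiB default that can be overridden through the environment variable read above. A minimal sketch, assuming the variable is set before TensorRT-LLM initializes its all-reduce workspaces (the 128 MiB value is only an example):

```python
import os

# Override the all-reduce fusion workspace size, in bytes; if unset, the
# 64 MiB default from getMaxRequiredWorkspaceSize() applies.
os.environ["TRTLLM_ALLREDUCE_FUSION_WORKSPACE_SIZE"] = str(128 * 1024 * 1024)
```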

tensorrt_llm/_torch/autotuner.py

Lines changed: 25 additions & 12 deletions
@@ -18,6 +18,7 @@
 
 import tensorrt_llm
 from tensorrt_llm._torch.distributed import Distributed
+from tensorrt_llm._utils import nvtx_range
 from tensorrt_llm.bindings.internal.runtime import delay_kernel
 from tensorrt_llm.logger import logger
 from tensorrt_llm.mapping import Mapping
@@ -831,8 +832,10 @@ def choose_one(
                 custom_op, runners, p.get_opt_shapes(), tuning_config)
             if not is_cache_hit:
                 # Initialize runner and tactic as None in case of no valid tactic or runners are found
-                best_runner_id, best_tactic, min_time, has_tuning_failure_occurred = self._profile_runners(
-                    custom_op, runners, tensors, p, tuning_config, **kwargs)
+                with nvtx_range(f"{custom_op}, shape {p.get_opt_shapes()}"):
+                    best_runner_id, best_tactic, min_time, has_tuning_failure_occurred = self._profile_runners(
+                        custom_op, runners, tensors, p, tuning_config,
+                        **kwargs)
                 new_tuning_failure_occurred = new_tuning_failure_occurred or has_tuning_failure_occurred
 
         self._maybe_sync_cache_data(tuning_config.distributed_tuning_strategy,
@@ -869,6 +872,13 @@ def _profile_runners(
         tuning_config: TuningConfig,
         **kwargs,
     ) -> float:
+        """Profile runners and select the best tactic.
+
+        For multi-rank profiling, only rank 0 performs the actual profiling
+        to avoid sync issues when different ranks select different tactics.
+        The results are then broadcasted to all other ranks.
+        """
+
         min_time = float('inf')
         has_tuning_failure_occurred = False
         best_runner_id, best_tactic = None, None
@@ -896,14 +906,15 @@ def _profile_runners(
 
             for tac in valid_tactics:
                 try:
-                    time_measured = self._profile_single_kernel(
-                        runner=runner,
-                        inputs=input_tensors,
-                        tactic=tac,
-                        tuning_config=tuning_config,
-                        use_cuda_graph=tuning_config.use_cuda_graph,
-                        **kwargs,
-                    )
+                    with nvtx_range(f"r{runner_id}, tactic {tac}"):
+                        time_measured = self._profile_single_kernel(
+                            runner=runner,
+                            inputs=input_tensors,
+                            tactic=tac,
+                            tuning_config=tuning_config,
+                            use_cuda_graph=tuning_config.use_cuda_graph,
+                            **kwargs,
+                        )
                 except Exception as e:
                     # Handle None tensors for optional inputs
                     shapes = self._get_input_sizes(input_tensors)
@@ -1015,13 +1026,14 @@ def pure_profile(stream: torch.cuda.Stream, repeat: int):
             stream.synchronize()
 
             # Delay the profiled kernel launch to eliminate affects of host time overhead in profiling.
-            # TODO: This is build time sensitive, O(tactic_num * impl_num * num_profile * tunable_ops)
-            # Consider apply a preprofiling to estimate the kernel execution time, then decide the necessity.
             if use_cuda_graph:
                 delay_kernel(self._CUDA_GRAPH_DELAY_MICRO_SECS, stream)
             else:
                 delay_kernel(self.stream_delay_micro_secs, stream)
 
+            if tuning_config.distributed_tuning_strategy == DistributedTuningStrategy.MERGE:
+                self._dist.barrier()
+
             start.record()
 
             if use_cuda_graph:
@@ -1039,6 +1051,7 @@ def pure_profile(stream: torch.cuda.Stream, repeat: int):
 
             return start.elapsed_time(end) / repeat
 
+        # warm up, no timing
         for _ in range(self.warmup):
            runner(input_tensor_batches[-1], tactic=tactic, **kwargs)
 
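The nvtx_range wrappers added above make each tuned op and each candidate tactic show up as nested ranges in an Nsight Systems timeline. A minimal sketch of the same nesting, using a dummy runner and a hypothetical profile_one helper standing in for AutoTuner._profile_single_kernel:

```python
from tensorrt_llm._utils import nvtx_range


def profile_one(runner, inputs, tactic):
    # Hypothetical stand-in for AutoTuner._profile_single_kernel: run once, no timing.
    return runner(inputs, tactic=tactic)


# Dummy runner and inputs so the example is self-contained.
runner = lambda inputs, tactic: sum(inputs)
inputs = [1, 2, 3]

# Outer range per (op, shape), inner range per (runner, tactic), mirroring the
# ranges added in choose_one() and _profile_runners() above.
with nvtx_range("trtllm::tunable_allreduce::allreduce, shape (256, 4096)"):
    for tactic in (0, 1, 2):  # illustrative tactic ids
        with nvtx_range(f"r0, tactic {tactic}"):
            profile_one(runner, inputs, tactic)
```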
tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py

Lines changed: 24 additions & 24 deletions
@@ -15,18 +15,18 @@ def _register_fake():
 
     @torch.library.register_fake("trtllm::allreduce")
     def allreduce(
-        input,
-        residual,
-        norm_weight,
-        scale,
-        bias,
-        workspace,
-        group,
-        strategy,
-        op,
-        eps,
-        trigger_completion_at_end,
-    ):
+        input: torch.Tensor,
+        residual: Optional[torch.Tensor],
+        norm_weight: Optional[torch.Tensor],
+        scale: Optional[torch.Tensor],
+        bias: Optional[torch.Tensor],
+        workspace: Optional[torch.Tensor],
+        group: List[int],
+        strategy: int,
+        op: int,
+        eps: float,
+        trigger_completion_at_end: bool,
+    ) -> List[torch.Tensor]:
         from tensorrt_llm.functional import AllReduceFusionOp
         if op == int(AllReduceFusionOp.NONE):
             return [torch.empty_like(input)]
@@ -61,19 +61,19 @@ def allreduce(
 
     @torch.library.register_fake("trtllm::allreduce_pg")
     def _(
-        input,
-        residual,
-        norm_weight,
-        scale,
-        bias,
-        workspace,
-        group,
-        rank,
+        input: torch.Tensor,
+        residual: Optional[torch.Tensor],
+        norm_weight: Optional[torch.Tensor],
+        scale: Optional[torch.Tensor],
+        bias: Optional[torch.Tensor],
+        workspace: Optional[torch.Tensor],
+        group: List[int],
+        rank: int,
         pg,
-        strategy,
-        op,
-        eps,
-        trigger_completion_at_end,
+        strategy: int,
+        op: int,
+        eps: float,
+        trigger_completion_at_end: bool,
     ):
         return allreduce(input, residual, norm_weight, scale, bias, workspace,
                          group, strategy, op, eps, trigger_completion_at_end)
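
The new annotations only touch the fake (meta) registrations: these functions never execute the kernel, they just declare output shapes and dtypes so tracing (torch.compile / torch.export) can propagate metadata. A minimal self-contained sketch of the same pattern for a hypothetical op, not part of TensorRT-LLM:

```python
from typing import Optional

import torch


# A trivial custom op so the fake registration below has something to attach to.
@torch.library.custom_op("example::scaled_copy", mutates_args=())
def scaled_copy(input: torch.Tensor, scale: Optional[torch.Tensor]) -> torch.Tensor:
    out = input.clone()
    return out * scale if scale is not None else out


# Fake implementation: shape/dtype bookkeeping only, no computation.
@scaled_copy.register_fake
def _(input: torch.Tensor, scale: Optional[torch.Tensor]) -> torch.Tensor:
    return torch.empty_like(input)
```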

tensorrt_llm/_torch/custom_ops/torch_custom_ops.py

Lines changed: 169 additions & 0 deletions
@@ -8,7 +8,9 @@
 import tensorrt_llm.quantization.utils.fp8_utils as fp8_utils
 from tensorrt_llm import deep_gemm
 from tensorrt_llm._utils import get_sm_version
+from tensorrt_llm.functional import AllReduceFusionOp, AllReduceStrategy
 from tensorrt_llm.logger import logger
+from tensorrt_llm.plugin.plugin import CustomAllReduceHelper
 
 from ..autotuner import (AutoTuner, ConstraintSpec, DistributedTuningStrategy,
                          DynamicTensorSpec, OptimizationProfile, TunableRunner,
@@ -1580,6 +1582,173 @@ def _(
     return x.new_empty((b, d), dtype=o_dtype)
 
 
+class AllReduceRunner(TunableRunner):
+    tuning_config = TuningConfig(
+        dynamic_tensor_specs=(DynamicTensorSpec(
+            0, 0, get_last_power_of_2_num_tokens_buckets(8192),
+            last_positive_power_of_2), ),
+        constraint_specs=(ConstraintSpec(1, 0, lambda shapes: shapes[0][0]), ),
+        distributed_tuning_strategy=DistributedTuningStrategy.MERGE,
+    )
+
+    def __init__(
+        self,
+        tp_size: int,
+        group: List[int],
+        op: int,
+        eps: float,
+        trigger_completion_at_end: bool,
+    ):
+        self.tp_size = tp_size
+        self.op = op
+        self.group = group
+        self.eps = eps
+        self.trigger_completion_at_end = trigger_completion_at_end
+
+    def unique_id(self):
+        return (
+            self.tp_size,
+            self.op,
+        )
+
+    def get_valid_tactics(
+        self,
+        inputs: List[torch.Tensor],
+        profile: OptimizationProfile,
+        **kwargs,
+    ) -> List[int]:
+        valid_strategies = [
+            # TODO: NCCL_SYMMETRIC will cause hang during tuning process
+            # AllReduceStrategy.NCCL_SYMMETRIC.value,
+            AllReduceStrategy.NCCL.value,
+        ]
+        # Fallback in allreduceOp is set to NCCL_SYMMETRIC as default
+        # So we need to check if the workspace size is too large to avoid hanging.
+        workspace_size = inputs[0].numel() * inputs[0].element_size()
+        max_workspace_size = CustomAllReduceHelper.max_workspace_size_auto(
+            self.tp_size,
+            support_deterministic=False,
+        )
+        if workspace_size > max_workspace_size:
+            return valid_strategies
+
+        valid_strategies.append(AllReduceStrategy.ONESHOT.value)
+
+        # Additional restrictions for TWOSHOT strategy
+        if inputs[0].shape[0] >= self.tp_size:
+            valid_strategies.append(AllReduceStrategy.TWOSHOT.value)
+
+        return valid_strategies
+
+    def forward(
+        self,
+        inputs: List[torch.Tensor],
+        tactic: int = -1,
+    ) -> torch.Tensor:
+        input, residual, norm_weight, scale, bias, workspace = inputs
+        if tactic == -1:
+            # TODO: Use NCCL instead of NCCL_SYMMETRIC to avoid hanging during tuning process
+            tactic = AllReduceStrategy.NCCL.value
+
+        return torch.ops.trtllm.allreduce(
+            input,
+            residual,
+            norm_weight,
+            scale,
+            bias,
+            workspace,
+            self.group,
+            tactic,
+            self.op,
+            self.eps,
+            self.trigger_completion_at_end,
+        )
+
+
+@torch.library.custom_op("trtllm::tunable_allreduce", mutates_args=())
+def tunable_allreduce(
+    input: torch.Tensor,
+    residual: Optional[torch.Tensor],
+    norm_weight: Optional[torch.Tensor],
+    scale: Optional[torch.Tensor],
+    bias: Optional[torch.Tensor],
+    workspace: Optional[torch.Tensor],
+    group: List[int],
+    op: int,
+    eps: float,
+    tp_size: int,
+    trigger_completion_at_end: bool,
+) -> List[torch.Tensor]:
+
+    tuner = AutoTuner.get()
+
+    allreduce_runner = AllReduceRunner(
+        tp_size,
+        group,
+        op,
+        eps,
+        trigger_completion_at_end,
+    )
+
+    _, best_tactic = tuner.choose_one(
+        "trtllm::tunable_allreduce::allreduce",
+        [allreduce_runner],
+        AllReduceRunner.tuning_config,
+        [input, residual, norm_weight, scale, bias, workspace],
+    )
+
+    return allreduce_runner(
+        [input, residual, norm_weight, scale, bias, workspace],
+        tactic=best_tactic,
+    )
+
+
+@tunable_allreduce.register_fake
+def _(
+    input: torch.Tensor,
+    residual: Optional[torch.Tensor],
+    norm_weight: Optional[torch.Tensor],
+    scale: Optional[torch.Tensor],
+    bias: Optional[torch.Tensor],
+    workspace: Optional[torch.Tensor],
+    group: List[int],
+    op: int,
+    eps: float,
+    tp_size: int,
+    trigger_completion_at_end: bool,
+) -> List[torch.Tensor]:
+    if op == int(AllReduceFusionOp.NONE):
+        return [torch.empty_like(input)]
+    elif op == int(AllReduceFusionOp.RESIDUAL_RMS_NORM):
+        norm_out = torch.empty_like(input)
+        residual_out = torch.empty_like(input)
+        return [norm_out, residual_out]
+    elif op == int(AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_FP8):
+        quant_out = torch.empty_like(input, dtype=torch.float8_e4m3fn)
+        residual_out = torch.empty_like(input)
+        return [quant_out, residual_out]
+    elif op == int(AllReduceFusionOp.RESIDUAL_RMS_NORM_OUT_QUANT_FP8):
+        norm_out = torch.empty_like(input)
+        quant_out = torch.empty_like(input, dtype=torch.float8_e4m3fn)
+        residual_out = torch.empty_like(input)
+        return [norm_out, quant_out, residual_out]
+    elif op == int(AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4):
+        fp4_shape, scale_shape = fp4_utils.get_fp4_shape(input.shape, 16)
+        quant_fp4 = input.new_empty(fp4_shape, dtype=torch.uint8)
+        scale_fp4 = input.new_empty(scale_shape, dtype=torch.uint8)
+        residual_out = torch.empty_like(input)
+        return [quant_fp4, scale_fp4, residual_out]
+    elif op == int(AllReduceFusionOp.RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4):
+        fp4_shape, scale_shape = fp4_utils.get_fp4_shape(input.shape, 16)
+        quant_fp4 = input.new_empty(fp4_shape, dtype=torch.uint8)
+        scale_fp4 = input.new_empty(scale_shape, dtype=torch.uint8)
+        norm_out = torch.empty_like(input)
+        residual_out = torch.empty_like(input)
+        return [norm_out, quant_fp4, scale_fp4, residual_out]
+    else:
+        return [torch.empty_like(input)]
+
+
 def get_event(event_idx: int):
     from ..utils import get_model_extra_attrs
     extra_attrs = get_model_extra_attrs()

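A usage sketch for the op defined above, assuming it is exposed as torch.ops.trtllm.tunable_allreduce and that the caller already holds the hidden states, workspace, and TP group (the 4-rank group and eps value are illustrative):

```python
import torch

from tensorrt_llm.functional import AllReduceFusionOp

# hidden_states: [num_tokens, hidden_size] tensor on this rank; workspace is the
# pre-allocated all-reduce workspace. Both are assumed to exist in the caller.
outputs = torch.ops.trtllm.tunable_allreduce(
    hidden_states,                 # input
    None,                          # residual (no fusion)
    None,                          # norm_weight
    None,                          # scale
    None,                          # bias
    workspace,
    [0, 1, 2, 3],                  # group: illustrative 4-rank TP group
    int(AllReduceFusionOp.NONE),   # op: plain all-reduce, no fused epilogue
    1e-5,                          # eps (unused without RMSNorm fusion)
    4,                             # tp_size
    True,                          # trigger_completion_at_end
)
reduced = outputs[0]
```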