Skip to content

Commit 872bb3c

Browse files
committed
[TRTLLM-8129][feat] Apply AutoTuner to AllReduce Op for strategy tuning.
Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
1 parent 3a5845e commit 872bb3c

File tree

5 files changed

+285
-35
lines changed

5 files changed

+285
-35
lines changed

tensorrt_llm/_torch/autotuner.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import torch
1616

1717
import tensorrt_llm
18+
from tensorrt_llm._utils import mpi_barrier
1819
from tensorrt_llm.bindings.internal.runtime import delay_kernel
1920
from tensorrt_llm.logger import logger
2021

@@ -534,8 +535,6 @@ def __init__(self, warmup=3, repeat=10, stream_delay_micro_secs=1000):
534535
# Add statistics tracking
535536
self.stats = AutoTunerStatistics()
536537

537-
self.profiling_debug = True
538-
539538
@classmethod
540539
def get(cls):
541540
if cls._instance is None:
@@ -745,6 +744,10 @@ def _profile_single_kernel(
745744
are used to ensure accurate timing.
746745
"""
747746
stream = torch.cuda.current_stream()
747+
748+
if self._is_sync_op(runner):
749+
mpi_barrier()
750+
748751
# warm up, no timing
749752
for _ in range(self.warmup):
750753
runner(inputs, tactic=tactic, **kwargs)
@@ -757,6 +760,9 @@ def _profile_single_kernel(
757760
start = torch.cuda.Event(enable_timing=True)
758761
end = torch.cuda.Event(enable_timing=True)
759762

763+
if self._is_sync_op(runner):
764+
mpi_barrier()
765+
760766
start.record(stream=stream)
761767
for _ in range(self.repeat):
762768
runner(inputs, tactic=tactic, **kwargs)
@@ -939,6 +945,9 @@ def _prepare_input_tensors(
939945
tensors.append(tensor)
940946
return tensors
941947

948+
def _is_sync_op(self, runner: TunableRunner) -> bool:
949+
return runner.__class__.__name__ in ["AllReduceRunner"]
950+
942951
def clear_cache(self) -> None:
943952
"""Clear the profiling cache."""
944953
self.profiling_cache.clear()

tensorrt_llm/_torch/custom_ops/torch_custom_ops.py

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import tensorrt_llm.quantization.utils.fp8_utils as fp8_utils
99
from tensorrt_llm import deep_gemm
1010
from tensorrt_llm._utils import get_sm_version
11+
from tensorrt_llm.functional import AllReduceFusionOp, AllReduceStrategy
1112

1213
from ..autotuner import (AutoTuner, ConstraintSpec, DynamicTensorSpec,
1314
OptimizationProfile, TunableRunner, TuningConfig)
@@ -1139,6 +1140,172 @@ def _(
11391140
return x.new_empty((b, d), dtype=o_dtype)
11401141

11411142

1143+
class AllReduceRunner(TunableRunner):
    """Tunable runner that lets the AutoTuner profile AllReduce strategies.

    Each tactic is an ``AllReduceStrategy`` enum value; the AutoTuner times
    ``forward`` for every tactic from ``get_valid_tactics`` and caches the
    fastest one per input shape.
    """

    # Fusion patterns this runner is able to tune.
    all_support_ops = {
        AllReduceFusionOp.NONE.value,
        AllReduceFusionOp.RESIDUAL_RMS_NORM.value,
    }

    # Tune over power-of-2 token counts on dim 0 of the input; the residual
    # (input index 1) is constrained to match the input's dim-0 size.
    tuning_config = TuningConfig(
        dynamic_tensor_specs=(DynamicTensorSpec(
            0, 0,
            (8192, 4096, 2048, 1024, 512, 256, 128, 64, 32, 16, 8, 4, 2, 1),
            last_positive_power_of_2), ),
        constraint_specs=(ConstraintSpec(1, 0, lambda shapes: shapes[0][0]), ),
    )

    def __init__(
        self,
        tp_size: int,
        group: List[int],
        op: int,
        eps: float,
        trigger_completion_at_end: bool,
    ):
        self.tp_size = tp_size
        self.op = op
        self._group = group
        self._eps = eps
        self._trigger_completion_at_end = trigger_completion_at_end

    def __hash__(self):
        return hash((self.tp_size, self.op))

    def __eq__(self, other):
        # Keep equality consistent with __hash__ so that value-equal runner
        # instances act as the same profiling-cache key (the original relied
        # on default identity equality, which defeats the value-based hash).
        return (isinstance(other, AllReduceRunner)
                and self.tp_size == other.tp_size and self.op == other.op)

    def get_valid_tactics(
        self,
        inputs: List[torch.Tensor],
        profile: OptimizationProfile,
        **kwargs,
    ) -> List[int]:
        """Return the AllReduce strategies worth profiling for these inputs."""
        valid_tactics = [
            AllReduceStrategy.NCCL.value,
            AllReduceStrategy.ONESHOT.value,
        ]
        # TWOSHOT is only valid when there is at least one token per rank.
        if inputs[0].shape[0] >= self.tp_size:
            valid_tactics.append(AllReduceStrategy.TWOSHOT.value)
        return valid_tactics

    def forward(
        self,
        inputs: List[torch.Tensor],
        tactic: int = -1,
    ) -> torch.Tensor:
        """Run one AllReduce with the given tactic (-1 falls back to NCCL)."""
        input, residual, norm_weight, scale, bias, workspace = inputs
        if tactic == -1:
            tactic = AllReduceStrategy.NCCL.value

        # Bug fix: return the op's outputs instead of discarding them — the
        # signature promises a Tensor result, and the original body returned
        # None implicitly.
        return torch.ops.trtllm.allreduce(
            input,
            residual,
            norm_weight,
            scale,
            bias,
            workspace,
            self._group,
            tactic,
            self.op,
            self._eps,
            self._trigger_completion_at_end,
        )
1210+
1211+
1212+
@torch.library.custom_op("trtllm::tunable_allreduce", mutates_args=())
def tunable_allreduce(
    input: torch.Tensor,
    residual: Optional[torch.Tensor],
    norm_weight: Optional[torch.Tensor],
    scale: Optional[torch.Tensor],
    bias: Optional[torch.Tensor],
    workspace: Optional[torch.Tensor],
    group: List[int],
    strategy: int,
    op: int,
    eps: float,
    tp_size: int,
    trigger_completion_at_end: bool,
) -> List[torch.Tensor]:
    """AllReduce whose strategy is selected by the AutoTuner.

    The ``strategy`` argument is accepted for signature parity with
    ``trtllm::allreduce`` but is not consulted here; the effective strategy
    is the tactic chosen by ``AutoTuner.choose_one``.
    """
    runner = AllReduceRunner(
        tp_size,
        group,
        op,
        eps,
        trigger_completion_at_end,
    )

    _, tactic = AutoTuner.get().choose_one(
        "trtllm::tunable_allreduce::allreduce",
        [runner],
        AllReduceRunner.tuning_config,
        [input, residual, norm_weight, scale, bias, workspace],
    )

    # -1 means tuning produced no choice; default to the NCCL strategy.
    tactic = AllReduceStrategy.NCCL.value if tactic == -1 else tactic

    return torch.ops.trtllm.allreduce(
        input,
        residual,
        norm_weight,
        scale,
        bias,
        workspace,
        group,
        tactic,
        op,
        eps,
        trigger_completion_at_end,
    )
1261+
1262+
1263+
@tunable_allreduce.register_fake
def _(
    input: torch.Tensor,
    residual: Optional[torch.Tensor],
    norm_weight: Optional[torch.Tensor],
    scale: Optional[torch.Tensor],
    bias: Optional[torch.Tensor],
    workspace: Optional[torch.Tensor],
    group: List[int],
    strategy: int,
    op: int,
    eps: float,
    tp_size: int,
    trigger_completion_at_end: bool,
) -> List[torch.Tensor]:
    """Fake (meta) implementation of ``trtllm::tunable_allreduce``.

    Allocates empty outputs matching the layout the real op produces for each
    fusion variant. Bug fixes vs. the original: the signature now includes
    ``tp_size`` (torch.library requires the fake's signature to match the
    op's), and the return annotation is ``List[torch.Tensor]`` to match what
    every branch actually returns.
    """
    if op == int(AllReduceFusionOp.NONE):
        return [torch.empty_like(input)]
    elif op == int(AllReduceFusionOp.RESIDUAL_RMS_NORM):
        norm_out = torch.empty_like(input)
        residual_out = torch.empty_like(input)
        return [norm_out, residual_out]
    elif op == int(AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_FP8):
        quant_out = torch.empty_like(input, dtype=torch.float8_e4m3fn)
        residual_out = torch.empty_like(input)
        return [quant_out, residual_out]
    elif op == int(AllReduceFusionOp.RESIDUAL_RMS_NORM_OUT_QUANT_FP8):
        norm_out = torch.empty_like(input)
        quant_out = torch.empty_like(input, dtype=torch.float8_e4m3fn)
        residual_out = torch.empty_like(input)
        return [norm_out, quant_out, residual_out]
    elif op == int(AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4):
        # NVFP4 packs values and per-block scales into uint8 buffers.
        fp4_shape, scale_shape = fp4_utils.get_fp4_shape(input.shape, 16)
        quant_fp4 = input.new_empty(fp4_shape, dtype=torch.uint8)
        scale_fp4 = input.new_empty(scale_shape, dtype=torch.uint8)
        residual_out = torch.empty_like(input)
        return [quant_fp4, scale_fp4, residual_out]
    elif op == int(AllReduceFusionOp.RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4):
        fp4_shape, scale_shape = fp4_utils.get_fp4_shape(input.shape, 16)
        quant_fp4 = input.new_empty(fp4_shape, dtype=torch.uint8)
        scale_fp4 = input.new_empty(scale_shape, dtype=torch.uint8)
        norm_out = torch.empty_like(input)
        residual_out = torch.empty_like(input)
        return [norm_out, quant_fp4, scale_fp4, residual_out]
    else:
        # Unknown fusion op: fall back to a single allreduce-shaped output.
        return [torch.empty_like(input)]
1307+
1308+
11421309
def get_event(event_idx: int):
11431310
from ..utils import get_model_extra_attrs
11441311
extra_attrs = get_model_extra_attrs()

tensorrt_llm/_torch/distributed/ops.py

Lines changed: 33 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -505,7 +505,6 @@ def __init__(self,
505505
self._disable_mpi = mpi_disabled()
506506

507507
self.all_reduce_op = torch.ops.trtllm.allreduce_pg if self._disable_mpi else torch.ops.trtllm.allreduce
508-
509508
if self.mapping.tp_size > 1:
510509
# When Strategy is UB, it is guaranteed that the workspace is not used.
511510
if self.strategy != AllReduceStrategy.UB:
@@ -574,6 +573,7 @@ def forward(
574573
input = input.contiguous() # Underlying op requires contiguous input
575574

576575
allreduce_strategy = self.strategy
576+
577577
if all_reduce_params is None:
578578
all_reduce_params = AllReduceParams()
579579

@@ -598,21 +598,38 @@ def forward(
598598
"pg": pg.boxed(),
599599
}
600600

601-
output = self.all_reduce_op(
602-
input=input,
603-
residual=all_reduce_params.residual,
604-
norm_weight=all_reduce_params.norm_weight,
605-
scale=all_reduce_params.scale,
606-
bias=all_reduce_params.bias,
607-
workspace=self.workspace,
608-
group=self.mapping.tp_group,
609-
strategy=allreduce_strategy,
610-
op=all_reduce_params.fusion_op,
611-
eps=all_reduce_params.eps,
612-
trigger_completion_at_end=all_reduce_params.
613-
trigger_completion_at_end,
614-
**additional_args,
615-
)
601+
if self.strategy == AllReduceStrategy.AUTOTUNE:
602+
output = torch.ops.trtllm.tunable_allreduce(
603+
input=input,
604+
residual=all_reduce_params.residual,
605+
norm_weight=all_reduce_params.norm_weight,
606+
scale=all_reduce_params.scale,
607+
bias=all_reduce_params.bias,
608+
workspace=self.workspace,
609+
group=self.mapping.tp_group,
610+
strategy=allreduce_strategy,
611+
op=all_reduce_params.fusion_op,
612+
eps=all_reduce_params.eps,
613+
tp_size=self.mapping.tp_size,
614+
trigger_completion_at_end=all_reduce_params.
615+
trigger_completion_at_end,
616+
)
617+
else:
618+
output = self.all_reduce_op(
619+
input=input,
620+
residual=all_reduce_params.residual,
621+
norm_weight=all_reduce_params.norm_weight,
622+
scale=all_reduce_params.scale,
623+
bias=all_reduce_params.bias,
624+
workspace=self.workspace,
625+
group=self.mapping.tp_group,
626+
strategy=allreduce_strategy,
627+
op=all_reduce_params.fusion_op,
628+
eps=all_reduce_params.eps,
629+
trigger_completion_at_end=all_reduce_params.
630+
trigger_completion_at_end,
631+
**additional_args,
632+
)
616633

617634
return output if len(output) > 1 else output[0]
618635

tensorrt_llm/functional.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3883,6 +3883,7 @@ class AllReduceStrategy(IntEnum):
38833883
LOWPRECISION = 6
38843884
MNNVL = 7
38853885
NCCL_SYMMETRIC = 8
3886+
AUTOTUNE = 9
38863887

38873888

38883889
class AllReduceFusionOp(IntEnum):

0 commit comments

Comments
 (0)