Commit c6c1a96

restore
Signed-off-by: Ludwig Schneider <lschneider@nvidia.com>
1 parent 9fcae0f commit c6c1a96

1 file changed: +254 -5 lines changed

tensorrt_llm/_torch/custom_ops/torch_custom_ops.py

Lines changed: 254 additions & 5 deletions
@@ -8,7 +8,9 @@
 import tensorrt_llm.quantization.utils.fp8_utils as fp8_utils
 from tensorrt_llm import deep_gemm
 from tensorrt_llm._utils import get_sm_version
+from tensorrt_llm.functional import AllReduceFusionOp, AllReduceStrategy
 from tensorrt_llm.logger import logger
+from tensorrt_llm.plugin.plugin import CustomAllReduceHelper

 from ..autotuner import (AutoTuner, ConstraintSpec, DistributedTuningStrategy,
                          DynamicTensorSpec, OptimizationProfile, TunableRunner,
@@ -693,6 +695,14 @@ def _(

 class NVFP4GemmUnifiedRunner(TunableRunner):
     runner_dict = dict()
+    tuning_config = TuningConfig(
+        dynamic_tensor_specs=(DynamicTensorSpec(
+            0, 0, get_last_power_of_2_num_tokens_buckets,
+            last_positive_power_of_2), ),
+        constraint_specs=(ConstraintSpec(2, 0, fp4_scale_infer_shape), ),
+        # nested tuning should always be independent
+        distributed_tuning_strategy=DistributedTuningStrategy.INDEPENDENT,
+    )

     def __init__(self, to_userbuffers: bool, output_dtype: torch.dtype,
                  allowed_backends: List[str]):
@@ -943,7 +953,7 @@ def nvfp4_gemm(
     _, best_tactic = tuner.choose_one(
         "trtllm::nvfp4_gemm::gemm",
         [runner],
-        FP4GemmRunner.
+        NVFP4GemmUnifiedRunner.
         tuning_config,  # All runners use the same tuning_config
         [act_fp4, weight, act_sf, weight_scale, alpha],
     )
@@ -1319,7 +1329,7 @@ def _(

 class FinegrainedMixedDtypeGemm(TunableRunner):
     _runner_dict = dict()
-    MAX_SUPPORTED_SM_VERSION = 90
+    MAX_SUPPORTED_SM_VERSION = 103

     def __init__(self, activation_dtype: torch.dtype, output_dtype: torch.dtype,
                  quant_mode: int):
@@ -1354,7 +1364,7 @@ def forward(self,

         if get_sm_version() > self.MAX_SUPPORTED_SM_VERSION:
             raise ValueError(
-                f"SM version {get_sm_version()} is not supported for W4A16 GEMM"
+                f"SM version {get_sm_version()} is not supported for W4A16/W4A8 finegrained mixed dtype GEMM"
             )

         activation, weights_packed, scales = inputs
@@ -1433,7 +1443,7 @@ def _(
     return input.new_empty((M, N), dtype=output_dtype)


-def fp8_swap_ab_gen_tuning_buckets(x: int):
+def deep_gemm_gen_tuning_buckets(x: int):
     buckets = tuple(range(8, 128, 8))
     if x >= 128:
         buckets += tuple(range(128, x, 128))
@@ -1443,7 +1453,7 @@ def fp8_swap_ab_gen_tuning_buckets(x: int):
 class fp8SwapABGemmRunner(TunableRunner):
     tuning_config = TuningConfig(
         dynamic_tensor_specs=(DynamicTensorSpec(
-            0, 0, fp8_swap_ab_gen_tuning_buckets), ),
+            0, 0, deep_gemm_gen_tuning_buckets), ),
         tune_max_num_tokens=4096,
     )
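For orientation, the renamed bucket generator behaves as before: it emits token-count buckets in steps of 8 up to 128 and then in steps of 128 up to the requested maximum. A minimal standalone sketch follows; the sample value 300 is only illustrative, and the trailing return is assumed since it falls outside the hunk.

# Sketch of the renamed helper; mirrors the body shown in the hunk above.
def deep_gemm_gen_tuning_buckets(x: int):
    buckets = tuple(range(8, 128, 8))  # 8, 16, ..., 120
    if x >= 128:
        buckets += tuple(range(128, x, 128))  # 128, 256, ...
    return buckets  # assumed return; not visible in this hunk

print(deep_gemm_gen_tuning_buckets(300))
# (8, 16, ..., 112, 120, 128, 256)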

@@ -1528,6 +1538,78 @@ def _(
     return input.new_empty((input.size(0), weight.size(0)), dtype=output_dtype)


+# The runner is used to trigger deepgemm jit during autotune.
+class Fp8BlockScalingGemmRunner(TunableRunner):
+    tuning_config = TuningConfig(
+        dynamic_tensor_specs=(DynamicTensorSpec(
+            0, 0, deep_gemm_gen_tuning_buckets), ),
+        tune_max_num_tokens=4096,
+    )
+
+    def get_valid_tactics(
+        self,
+        inputs: List[torch.Tensor],
+        profile: OptimizationProfile,
+    ) -> List[int]:
+        return [0]
+
+    def forward(
+        self,
+        inputs: List[torch.Tensor],
+        tactic: int = -1,
+    ) -> torch.Tensor:
+        a, b, a_scale, b_scale = inputs
+        return torch.ops.trtllm.fp8_block_scaling_gemm_impl(
+            a, b, a_scale, b_scale)
+
+
+def get_fp8_block_scaling_gemm_constraint_spec():
+    # The implementation aligns with the fp8_quantize_1x128 custom op.
+    def fp8_quantize_1x128_sm90_constrant(inputs: List[List[int]]):
+        pad_m = fp4_utils.pad_up(inputs[0][0], 4)
+        blocked_n = (inputs[0][1] + 127) // 128
+        return fp4_utils.pad_up(pad_m * blocked_n * 4, 128) // 4
+
+    if get_sm_version() >= 100:
+        return (ConstraintSpec(2, 1, lambda inputs: inputs[0][0]), )
+    else:
+        return (ConstraintSpec(2, 0, fp8_quantize_1x128_sm90_constrant), )
+
+
+@torch.library.custom_op("trtllm::fp8_block_scaling_gemm", mutates_args=())
+def fp8_block_scaling_gemm(
+    a: torch.Tensor,
+    b: torch.Tensor,
+    a_scale: torch.Tensor,
+    b_scale: torch.Tensor,
+    tune_max_num_tokens: int = 4096,
+) -> torch.Tensor:
+    tuner = AutoTuner.get()
+    fp8_block_scaling_gemm_runner = Fp8BlockScalingGemmRunner()
+    Fp8BlockScalingGemmRunner.tuning_config.tune_max_num_tokens = tune_max_num_tokens
+
+    Fp8BlockScalingGemmRunner.tuning_config.constraint_specs = get_fp8_block_scaling_gemm_constraint_spec(
+    )
+
+    _, best_tactic = tuner.choose_one(
+        "trtllm::fp8_block_scaling_gemm",
+        [fp8_block_scaling_gemm_runner],
+        Fp8BlockScalingGemmRunner.tuning_config,
+        [a, b, a_scale, b_scale],
+    )
+    return fp8_block_scaling_gemm_runner(
+        inputs=[a, b, a_scale, b_scale],
+        tactic=best_tactic,
+    )
+
+
+@fp8_block_scaling_gemm.register_fake
+def _(a, b, a_scale, b_scale, tune_max_num_tokens=4096):
+    m = a.shape[0]
+    n = b.shape[0]
+    return a.new_empty((m, n), dtype=torch.bfloat16)
+
+
 @torch.library.custom_op("trtllm::silu_and_mul", mutates_args=())
 def silu_and_mul(x: torch.Tensor,
                  scale: Optional[torch.Tensor] = None,
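As a rough orientation for the new trtllm::fp8_block_scaling_gemm entry point, here is a hedged usage sketch. Only the (a, b, a_scale, b_scale) argument order and the bfloat16 [m, n] output are taken from the diff; the FP8 E4M3 operand dtype and the scale-tensor layouts (1x128 activation blocks, 128x128 weight blocks) are assumptions inferred from the constraint specs and the op name, not a documented contract.

import torch

# Assumed problem size; k and n are multiples of 128 so the block-scale
# shapes below come out exact.
m, k, n = 256, 512, 1024

# Assumption: FP8 E4M3 activations [m, k] and weights [n, k].
a = torch.randn(m, k, device="cuda").to(torch.float8_e4m3fn)
b = torch.randn(n, k, device="cuda").to(torch.float8_e4m3fn)

# Assumed scale layouts: the SM100 constraint in this hunk pins a_scale
# dim 1 to m, suggesting a [k/128, m] activation-scale tensor; the weight
# scale is assumed to be one value per 128x128 block.
a_scale = torch.ones(k // 128, m, device="cuda", dtype=torch.float32)
b_scale = torch.ones(n // 128, k // 128, device="cuda", dtype=torch.float32)

# Signature as registered above; the fake kernel reports a bfloat16 [m, n]
# result, and AutoTuner.choose_one selects the tactic inside the op.
out = torch.ops.trtllm.fp8_block_scaling_gemm(a, b, a_scale, b_scale)
assert out.shape == (m, n) and out.dtype == torch.bfloat16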
@@ -1572,6 +1654,173 @@ def _(
     return x.new_empty((b, d), dtype=o_dtype)


+class AllReduceRunner(TunableRunner):
+    tuning_config = TuningConfig(
+        dynamic_tensor_specs=(DynamicTensorSpec(
+            0, 0, get_last_power_of_2_num_tokens_buckets(8192),
+            last_positive_power_of_2), ),
+        constraint_specs=(ConstraintSpec(1, 0, lambda shapes: shapes[0][0]), ),
+        distributed_tuning_strategy=DistributedTuningStrategy.MERGE,
+    )
+
+    def __init__(
+        self,
+        tp_size: int,
+        group: List[int],
+        op: int,
+        eps: float,
+        trigger_completion_at_end: bool,
+    ):
+        self.tp_size = tp_size
+        self.op = op
+        self.group = group
+        self.eps = eps
+        self.trigger_completion_at_end = trigger_completion_at_end
+
+    def unique_id(self):
+        return (
+            self.tp_size,
+            self.op,
+        )
+
+    def get_valid_tactics(
+        self,
+        inputs: List[torch.Tensor],
+        profile: OptimizationProfile,
+        **kwargs,
+    ) -> List[int]:
+        valid_strategies = [
+            # TODO: NCCL_SYMMETRIC will cause hang during tuning process
+            # AllReduceStrategy.NCCL_SYMMETRIC.value,
+            AllReduceStrategy.NCCL.value,
+        ]
+        # Fallback in allreduceOp is set to NCCL_SYMMETRIC as default
+        # So we need to check if the workspace size is too large to avoid hanging.
+        workspace_size = inputs[0].numel() * inputs[0].element_size()
+        max_workspace_size = CustomAllReduceHelper.max_workspace_size_auto(
+            self.tp_size,
+            support_deterministic=False,
+        )
+        if workspace_size > max_workspace_size:
+            return valid_strategies
+
+        valid_strategies.append(AllReduceStrategy.ONESHOT.value)
+
+        # Additional restrictions for TWOSHOT strategy
+        if inputs[0].shape[0] >= self.tp_size:
+            valid_strategies.append(AllReduceStrategy.TWOSHOT.value)
+
+        return valid_strategies
+
+    def forward(
+        self,
+        inputs: List[torch.Tensor],
+        tactic: int = -1,
+    ) -> torch.Tensor:
+        input, residual, norm_weight, scale, bias, workspace = inputs
+        if tactic == -1:
+            # TODO: Use NCCL instead of NCCL_SYMMETRIC to avoid hanging during tuning process
+            tactic = AllReduceStrategy.NCCL.value
+
+        return torch.ops.trtllm.allreduce(
+            input,
+            residual,
+            norm_weight,
+            scale,
+            bias,
+            workspace,
+            self.group,
+            tactic,
+            self.op,
+            self.eps,
+            self.trigger_completion_at_end,
+        )
+
+
+@torch.library.custom_op("trtllm::tunable_allreduce", mutates_args=())
+def tunable_allreduce(
+    input: torch.Tensor,
+    residual: Optional[torch.Tensor],
+    norm_weight: Optional[torch.Tensor],
+    scale: Optional[torch.Tensor],
+    bias: Optional[torch.Tensor],
+    workspace: Optional[torch.Tensor],
+    group: List[int],
+    strategy: int,
+    op: int,
+    eps: float,
+    trigger_completion_at_end: bool,
+) -> List[torch.Tensor]:
+
+    tuner = AutoTuner.get()
+
+    allreduce_runner = AllReduceRunner(
+        len(group),
+        group,
+        op,
+        eps,
+        trigger_completion_at_end,
+    )
+
+    _, best_tactic = tuner.choose_one(
+        "trtllm::tunable_allreduce::allreduce",
+        [allreduce_runner],
+        AllReduceRunner.tuning_config,
+        [input, residual, norm_weight, scale, bias, workspace],
+    )
+
+    return allreduce_runner(
+        [input, residual, norm_weight, scale, bias, workspace],
+        tactic=best_tactic,
+    )
+
+
+@tunable_allreduce.register_fake
+def _(
+    input: torch.Tensor,
+    residual: Optional[torch.Tensor],
+    norm_weight: Optional[torch.Tensor],
+    scale: Optional[torch.Tensor],
+    bias: Optional[torch.Tensor],
+    workspace: Optional[torch.Tensor],
+    group: List[int],
+    strategy: int,
+    op: int,
+    eps: float,
+    trigger_completion_at_end: bool,
+) -> List[torch.Tensor]:
+    if op == int(AllReduceFusionOp.NONE):
+        return [torch.empty_like(input)]
+    elif op == int(AllReduceFusionOp.RESIDUAL_RMS_NORM):
+        norm_out = torch.empty_like(input)
+        residual_out = torch.empty_like(input)
+        return [norm_out, residual_out]
+    elif op == int(AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_FP8):
+        quant_out = torch.empty_like(input, dtype=torch.float8_e4m3fn)
+        residual_out = torch.empty_like(input)
+        return [quant_out, residual_out]
+    elif op == int(AllReduceFusionOp.RESIDUAL_RMS_NORM_OUT_QUANT_FP8):
+        norm_out = torch.empty_like(input)
+        quant_out = torch.empty_like(input, dtype=torch.float8_e4m3fn)
+        residual_out = torch.empty_like(input)
+        return [norm_out, quant_out, residual_out]
+    elif op == int(AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4):
+        fp4_shape, scale_shape = fp4_utils.get_fp4_shape(input.shape, 16)
+        quant_fp4 = input.new_empty(fp4_shape, dtype=torch.uint8)
+        scale_fp4 = input.new_empty(scale_shape, dtype=torch.uint8)
+        residual_out = torch.empty_like(input)
+        return [quant_fp4, scale_fp4, residual_out]
+    elif op == int(AllReduceFusionOp.RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4):
+        fp4_shape, scale_shape = fp4_utils.get_fp4_shape(input.shape, 16)
+        quant_fp4 = input.new_empty(fp4_shape, dtype=torch.uint8)
+        scale_fp4 = input.new_empty(scale_shape, dtype=torch.uint8)
+        norm_out = torch.empty_like(input)
+        residual_out = torch.empty_like(input)
+        return [norm_out, quant_fp4, scale_fp4, residual_out]
+    else:
+        return [torch.empty_like(input)]
+
+
 def get_event(event_idx: int):
     from ..utils import get_model_extra_attrs
     extra_attrs = get_model_extra_attrs()
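Finally, a hedged sketch of how the new trtllm::tunable_allreduce op might be called on each tensor-parallel rank. The positional argument order follows the signature added in this commit; the fusion-op choice, eps value, and the assumption that a workspace tensor and process group are already prepared the way the existing allreduce path prepares them are illustrative only.

from typing import List, Optional

import torch

from tensorrt_llm.functional import AllReduceFusionOp, AllReduceStrategy


def fused_residual_rmsnorm_allreduce(
        hidden: torch.Tensor,
        residual: torch.Tensor,
        norm_weight: torch.Tensor,
        workspace: Optional[torch.Tensor],
        group: List[int],
) -> List[torch.Tensor]:
    # Sketch only: assumes an initialized tensor-parallel setup and a
    # workspace allocated as for the existing trtllm::allreduce op.
    # For RESIDUAL_RMS_NORM the op returns [norm_out, residual_out],
    # matching the register_fake branch in the diff above.
    return torch.ops.trtllm.tunable_allreduce(
        hidden,  # input
        residual,  # residual added before RMSNorm
        norm_weight,  # RMSNorm weight
        None,  # scale (unused for this fusion op)
        None,  # bias
        workspace,
        group,  # tensor-parallel ranks, e.g. [0, 1, 2, 3]
        AllReduceStrategy.NCCL.value,  # strategy; the runner picks the tactic
        int(AllReduceFusionOp.RESIDUAL_RMS_NORM),
        1e-5,  # eps
        True,  # trigger_completion_at_end
    )

During autotuning, AllReduceRunner.get_valid_tactics restricts the candidate strategies based on workspace size and token count, as shown in the hunk above.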
