 import tensorrt_llm.quantization.utils.fp8_utils as fp8_utils
 from tensorrt_llm import deep_gemm
 from tensorrt_llm._utils import get_sm_version
+from tensorrt_llm.functional import AllReduceFusionOp, AllReduceStrategy
 from tensorrt_llm.logger import logger
+from tensorrt_llm.plugin.plugin import CustomAllReduceHelper

 from ..autotuner import (AutoTuner, ConstraintSpec, DistributedTuningStrategy,
                          DynamicTensorSpec, OptimizationProfile, TunableRunner,
@@ -693,6 +695,14 @@ def _(

 class NVFP4GemmUnifiedRunner(TunableRunner):
     runner_dict = dict()
+    tuning_config = TuningConfig(
+        dynamic_tensor_specs=(DynamicTensorSpec(
+            0, 0, get_last_power_of_2_num_tokens_buckets,
+            last_positive_power_of_2), ),
+        constraint_specs=(ConstraintSpec(2, 0, fp4_scale_infer_shape), ),
+        # nested tuning should always be independent
+        distributed_tuning_strategy=DistributedTuningStrategy.INDEPENDENT,
+    )

     def __init__(self, to_userbuffers: bool, output_dtype: torch.dtype,
                  allowed_backends: List[str]):
@@ -943,7 +953,7 @@ def nvfp4_gemm(
     _, best_tactic = tuner.choose_one(
         "trtllm::nvfp4_gemm::gemm",
         [runner],
-        FP4GemmRunner.
+        NVFP4GemmUnifiedRunner.
         tuning_config,  # All runners use the same tuning_config
         [act_fp4, weight, act_sf, weight_scale, alpha],
     )
@@ -1319,7 +1329,7 @@ def _(

 class FinegrainedMixedDtypeGemm(TunableRunner):
     _runner_dict = dict()
-    MAX_SUPPORTED_SM_VERSION = 90
+    MAX_SUPPORTED_SM_VERSION = 103

     def __init__(self, activation_dtype: torch.dtype, output_dtype: torch.dtype,
                  quant_mode: int):
@@ -1354,7 +1364,7 @@ def forward(self,

         if get_sm_version() > self.MAX_SUPPORTED_SM_VERSION:
             raise ValueError(
-                f"SM version {get_sm_version()} is not supported for W4A16 GEMM"
+                f"SM version {get_sm_version()} is not supported for W4A16/W4A8 finegrained mixed dtype GEMM"
             )

         activation, weights_packed, scales = inputs
@@ -1433,7 +1443,7 @@ def _(
     return input.new_empty((M, N), dtype=output_dtype)


-def fp8_swap_ab_gen_tuning_buckets(x: int):
+def deep_gemm_gen_tuning_buckets(x: int):
     buckets = tuple(range(8, 128, 8))
     if x >= 128:
         buckets += tuple(range(128, x, 128))
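
For a concrete sense of what the renamed helper generates (assuming it simply returns the buckets tuple built above, which this hunk cuts off), the buckets step by 8 up to 120 and then by 128 up to x:

# Illustrative sketch only: mirrors deep_gemm_gen_tuning_buckets as shown in this diff.
def _gen_buckets_sketch(x: int):
    buckets = tuple(range(8, 128, 8))          # 8, 16, ..., 120
    if x >= 128:
        buckets += tuple(range(128, x, 128))   # 128, 256, ..., < x
    return buckets

assert _gen_buckets_sketch(64) == (8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120)
assert _gen_buckets_sketch(512)[-4:] == (120, 128, 256, 384)
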
@@ -1443,7 +1453,7 @@ def fp8_swap_ab_gen_tuning_buckets(x: int):
 class fp8SwapABGemmRunner(TunableRunner):
     tuning_config = TuningConfig(
         dynamic_tensor_specs=(DynamicTensorSpec(
-            0, 0, fp8_swap_ab_gen_tuning_buckets), ),
+            0, 0, deep_gemm_gen_tuning_buckets), ),
         tune_max_num_tokens=4096,
     )

@@ -1528,6 +1538,78 @@ def _(
     return input.new_empty((input.size(0), weight.size(0)), dtype=output_dtype)


+# This runner is used to trigger the DeepGEMM JIT during autotuning.
+class Fp8BlockScalingGemmRunner(TunableRunner):
+    tuning_config = TuningConfig(
+        dynamic_tensor_specs=(DynamicTensorSpec(
+            0, 0, deep_gemm_gen_tuning_buckets), ),
+        tune_max_num_tokens=4096,
+    )
+
+    def get_valid_tactics(
+        self,
+        inputs: List[torch.Tensor],
+        profile: OptimizationProfile,
+    ) -> List[int]:
+        return [0]
+
+    def forward(
+        self,
+        inputs: List[torch.Tensor],
+        tactic: int = -1,
+    ) -> torch.Tensor:
+        a, b, a_scale, b_scale = inputs
+        return torch.ops.trtllm.fp8_block_scaling_gemm_impl(
+            a, b, a_scale, b_scale)
+
+
+def get_fp8_block_scaling_gemm_constraint_spec():
+    # The implementation aligns with the fp8_quantize_1x128 custom op.
+    def fp8_quantize_1x128_sm90_constraint(inputs: List[List[int]]):
+        pad_m = fp4_utils.pad_up(inputs[0][0], 4)
+        blocked_n = (inputs[0][1] + 127) // 128
+        return fp4_utils.pad_up(pad_m * blocked_n * 4, 128) // 4
+
+    if get_sm_version() >= 100:
+        return (ConstraintSpec(2, 1, lambda inputs: inputs[0][0]), )
+    else:
+        return (ConstraintSpec(2, 0, fp8_quantize_1x128_sm90_constraint), )
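
A small worked example of the SM90 branch, assuming fp4_utils.pad_up(v, m) rounds v up to the next multiple of m: for an activation of shape [13, 100], both paddings kick in and the constrained dim 0 of the third input (the activation scale) comes out to 32.

# Illustrative arithmetic only; pad_up is re-implemented here under the
# assumption that it rounds up to the next multiple of m.
def pad_up(v: int, m: int) -> int:
    return (v + m - 1) // m * m

m, k = 13, 100                                  # activation shape [m, k]
pad_m = pad_up(m, 4)                            # 16
blocked_n = (k + 127) // 128                    # 1
print(pad_up(pad_m * blocked_n * 4, 128) // 4)  # 32 -> expected a_scale dim 0 on SM90
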
+
+
+@torch.library.custom_op("trtllm::fp8_block_scaling_gemm", mutates_args=())
+def fp8_block_scaling_gemm(
+    a: torch.Tensor,
+    b: torch.Tensor,
+    a_scale: torch.Tensor,
+    b_scale: torch.Tensor,
+    tune_max_num_tokens: int = 4096,
+) -> torch.Tensor:
+    tuner = AutoTuner.get()
+    fp8_block_scaling_gemm_runner = Fp8BlockScalingGemmRunner()
+    Fp8BlockScalingGemmRunner.tuning_config.tune_max_num_tokens = tune_max_num_tokens
+
+    Fp8BlockScalingGemmRunner.tuning_config.constraint_specs = get_fp8_block_scaling_gemm_constraint_spec(
+    )
+
+    _, best_tactic = tuner.choose_one(
+        "trtllm::fp8_block_scaling_gemm",
+        [fp8_block_scaling_gemm_runner],
+        Fp8BlockScalingGemmRunner.tuning_config,
+        [a, b, a_scale, b_scale],
+    )
+    return fp8_block_scaling_gemm_runner(
+        inputs=[a, b, a_scale, b_scale],
+        tactic=best_tactic,
+    )
+
+
+@fp8_block_scaling_gemm.register_fake
+def _(a, b, a_scale, b_scale, tune_max_num_tokens=4096):
+    m = a.shape[0]
+    n = b.shape[0]
+    return a.new_empty((m, n), dtype=torch.bfloat16)
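
A minimal call-site sketch, assuming the tensorrt_llm extension is loaded. With meta tensors the call dispatches to the register_fake above, so no GPU is needed, and the scale tensors below are shape placeholders only, not the layout the real kernel expects:

import torch

m, n, k = 256, 4096, 7168
a = torch.empty(m, k, device="meta", dtype=torch.float8_e4m3fn)
b = torch.empty(n, k, device="meta", dtype=torch.float8_e4m3fn)
# Placeholder scale tensors: the fake impl ignores them entirely.
a_scale = torch.empty(k // 128, m, device="meta", dtype=torch.float32)
b_scale = torch.empty(n // 128, k // 128, device="meta", dtype=torch.float32)

out = torch.ops.trtllm.fp8_block_scaling_gemm(a, b, a_scale, b_scale)
print(out.shape, out.dtype)  # torch.Size([256, 4096]) torch.bfloat16
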
+
+
 @torch.library.custom_op("trtllm::silu_and_mul", mutates_args=())
 def silu_and_mul(x: torch.Tensor,
                  scale: Optional[torch.Tensor] = None,
@@ -1572,6 +1654,173 @@ def _(
     return x.new_empty((b, d), dtype=o_dtype)


+class AllReduceRunner(TunableRunner):
+    tuning_config = TuningConfig(
+        dynamic_tensor_specs=(DynamicTensorSpec(
+            0, 0, get_last_power_of_2_num_tokens_buckets(8192),
+            last_positive_power_of_2), ),
+        constraint_specs=(ConstraintSpec(1, 0, lambda shapes: shapes[0][0]), ),
+        distributed_tuning_strategy=DistributedTuningStrategy.MERGE,
+    )
+
+    def __init__(
+        self,
+        tp_size: int,
+        group: List[int],
+        op: int,
+        eps: float,
+        trigger_completion_at_end: bool,
+    ):
+        self.tp_size = tp_size
+        self.op = op
+        self.group = group
+        self.eps = eps
+        self.trigger_completion_at_end = trigger_completion_at_end
+
+    def unique_id(self):
+        return (
+            self.tp_size,
+            self.op,
+        )
+
+    def get_valid_tactics(
+        self,
+        inputs: List[torch.Tensor],
+        profile: OptimizationProfile,
+        **kwargs,
+    ) -> List[int]:
+        valid_strategies = [
+            # TODO: NCCL_SYMMETRIC causes a hang during the tuning process.
+            # AllReduceStrategy.NCCL_SYMMETRIC.value,
+            AllReduceStrategy.NCCL.value,
+        ]
+        # The allreduce op falls back to NCCL_SYMMETRIC by default, so check the
+        # workspace size and skip the ONESHOT/TWOSHOT strategies when it is too
+        # large, to avoid hanging.
+        workspace_size = inputs[0].numel() * inputs[0].element_size()
+        max_workspace_size = CustomAllReduceHelper.max_workspace_size_auto(
+            self.tp_size,
+            support_deterministic=False,
+        )
+        if workspace_size > max_workspace_size:
+            return valid_strategies
+
+        valid_strategies.append(AllReduceStrategy.ONESHOT.value)
+
+        # Additional restriction for the TWOSHOT strategy
+        if inputs[0].shape[0] >= self.tp_size:
+            valid_strategies.append(AllReduceStrategy.TWOSHOT.value)
+
+        return valid_strategies
+
+    def forward(
+        self,
+        inputs: List[torch.Tensor],
+        tactic: int = -1,
+    ) -> torch.Tensor:
+        input, residual, norm_weight, scale, bias, workspace = inputs
+        if tactic == -1:
+            # TODO: Use NCCL instead of NCCL_SYMMETRIC to avoid a hang during
+            # the tuning process.
+            tactic = AllReduceStrategy.NCCL.value
+
+        return torch.ops.trtllm.allreduce(
+            input,
+            residual,
+            norm_weight,
+            scale,
+            bias,
+            workspace,
+            self.group,
+            tactic,
+            self.op,
+            self.eps,
+            self.trigger_completion_at_end,
+        )
+
+
+@torch.library.custom_op("trtllm::tunable_allreduce", mutates_args=())
+def tunable_allreduce(
+    input: torch.Tensor,
+    residual: Optional[torch.Tensor],
+    norm_weight: Optional[torch.Tensor],
+    scale: Optional[torch.Tensor],
+    bias: Optional[torch.Tensor],
+    workspace: Optional[torch.Tensor],
+    group: List[int],
+    strategy: int,
+    op: int,
+    eps: float,
+    trigger_completion_at_end: bool,
+) -> List[torch.Tensor]:
+
+    tuner = AutoTuner.get()
+
+    allreduce_runner = AllReduceRunner(
+        len(group),
+        group,
+        op,
+        eps,
+        trigger_completion_at_end,
+    )
+
+    _, best_tactic = tuner.choose_one(
+        "trtllm::tunable_allreduce::allreduce",
+        [allreduce_runner],
+        AllReduceRunner.tuning_config,
+        [input, residual, norm_weight, scale, bias, workspace],
+    )
+
+    return allreduce_runner(
+        [input, residual, norm_weight, scale, bias, workspace],
+        tactic=best_tactic,
+    )
+
+
+@tunable_allreduce.register_fake
+def _(
+    input: torch.Tensor,
+    residual: Optional[torch.Tensor],
+    norm_weight: Optional[torch.Tensor],
+    scale: Optional[torch.Tensor],
+    bias: Optional[torch.Tensor],
+    workspace: Optional[torch.Tensor],
+    group: List[int],
+    strategy: int,
+    op: int,
+    eps: float,
+    trigger_completion_at_end: bool,
+) -> List[torch.Tensor]:
+    if op == int(AllReduceFusionOp.NONE):
+        return [torch.empty_like(input)]
+    elif op == int(AllReduceFusionOp.RESIDUAL_RMS_NORM):
+        norm_out = torch.empty_like(input)
+        residual_out = torch.empty_like(input)
+        return [norm_out, residual_out]
+    elif op == int(AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_FP8):
+        quant_out = torch.empty_like(input, dtype=torch.float8_e4m3fn)
+        residual_out = torch.empty_like(input)
+        return [quant_out, residual_out]
+    elif op == int(AllReduceFusionOp.RESIDUAL_RMS_NORM_OUT_QUANT_FP8):
+        norm_out = torch.empty_like(input)
+        quant_out = torch.empty_like(input, dtype=torch.float8_e4m3fn)
+        residual_out = torch.empty_like(input)
+        return [norm_out, quant_out, residual_out]
+    elif op == int(AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4):
+        fp4_shape, scale_shape = fp4_utils.get_fp4_shape(input.shape, 16)
+        quant_fp4 = input.new_empty(fp4_shape, dtype=torch.uint8)
+        scale_fp4 = input.new_empty(scale_shape, dtype=torch.uint8)
+        residual_out = torch.empty_like(input)
+        return [quant_fp4, scale_fp4, residual_out]
+    elif op == int(AllReduceFusionOp.RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4):
+        fp4_shape, scale_shape = fp4_utils.get_fp4_shape(input.shape, 16)
+        quant_fp4 = input.new_empty(fp4_shape, dtype=torch.uint8)
+        scale_fp4 = input.new_empty(scale_shape, dtype=torch.uint8)
+        norm_out = torch.empty_like(input)
+        residual_out = torch.empty_like(input)
+        return [norm_out, quant_fp4, scale_fp4, residual_out]
+    else:
+        return [torch.empty_like(input)]
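
A shape-propagation sketch for the RESIDUAL_RMS_NORM fusion, again assuming the tensorrt_llm extension is importable; with meta tensors the call resolves through the fake registration above instead of launching a real collective:

import torch
from tensorrt_llm.functional import AllReduceFusionOp, AllReduceStrategy

x = torch.empty(128, 4096, device="meta", dtype=torch.bfloat16)
residual = torch.empty_like(x)
norm_weight = torch.empty(4096, device="meta", dtype=torch.bfloat16)

outs = torch.ops.trtllm.tunable_allreduce(
    x, residual, norm_weight, None, None, None,
    group=[0, 1], strategy=AllReduceStrategy.NCCL.value,
    op=int(AllReduceFusionOp.RESIDUAL_RMS_NORM), eps=1e-6,
    trigger_completion_at_end=True)
print([o.shape for o in outs])  # [torch.Size([128, 4096]), torch.Size([128, 4096])]
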
+
+
 def get_event(event_idx: int):
     from ..utils import get_model_extra_attrs
     extra_attrs = get_model_extra_attrs()