11from operator import getitem
2- from typing import List , Optional
2+ from typing import Callable , List , Optional
33
44import torch
55from torch ._inductor .pattern_matcher import (MULTIPLE , CallFunction , Ignored ,
1414from tensorrt_llm .mapping import Mapping
1515
1616
17- def register_ar_residual_norm (custom_pass : PatternMatcherPass ,
18- mapping : Mapping ):
17+ def register_ar_residual_norm (custom_pass : PatternMatcherPass , mapping : Mapping ,
18+ allreduce_func : Callable ):
1919 residual_key = KeywordArg ("residual" )
2020 trtllm_allreduce_default = CallFunction (
21- torch . ops . trtllm . allreduce . default , KeywordArg ("input" ), None , None ,
22- None , None , KeywordArg ("workspace" ), mapping .tp_group ,
23- KeywordArg ( "strategy" ), int (AllReduceFusionOp .NONE ), Ignored (),
21+ allreduce_func . default , KeywordArg ("input" ), None , None , None , None ,
22+ KeywordArg ("workspace" ), mapping .tp_group , KeywordArg ( "strategy" ) ,
23+ int (AllReduceFusionOp .NONE ), Ignored (),
2424 KeywordArg ("trigger_completion_at_end" ))
2525 getitem_x = CallFunction (getitem , trtllm_allreduce_default , 0 )
2626 add_Tensor = CallFunction (aten .add .Tensor ,
@@ -56,7 +56,7 @@ def target_pattern(
5656 eps : float ,
5757 trigger_completion_at_end : bool ,
5858 ):
59- all_reduce_output = torch . ops . trtllm . allreduce (
59+ all_reduce_output = allreduce_func (
6060 input , residual , norm_weight , None , None , workspace ,
6161 mapping .tp_group , int (strategy ),
6262 int (AllReduceFusionOp .RESIDUAL_RMS_NORM ), float (eps ),
@@ -111,10 +111,11 @@ def check_non_ub_strategy(match, strategy_node) -> bool:
111111
112112
113113def register_ar_residual_norm_out_fp8_quant (custom_pass : PatternMatcherPass ,
114- mapping : Mapping ):
114+ mapping : Mapping ,
115+ allreduce_func : Callable ):
115116 input_node = KeywordArg ("input" )
116117 strategy_node = KeywordArg ("strategy" )
117- allreduce_default = CallFunction (torch . ops . trtllm . allreduce .default ,
118+ allreduce_default = CallFunction (allreduce_func .default ,
118119 input_node ,
119120 KeywordArg ("residual" ),
120121 KeywordArg ("gamma" ),
@@ -165,7 +166,7 @@ def target_pattern(
165166 scale : torch .Tensor ,
166167 trigger_completion_at_end : bool ,
167168 ):
168- allreduce = torch . ops . trtllm . allreduce (
169+ allreduce = allreduce_func (
169170 input , residual , gamma , scale , None , workspace , mapping .tp_group ,
170171 int (strategy ),
171172 int (AllReduceFusionOp .RESIDUAL_RMS_NORM_OUT_QUANT_FP8 ), float (eps ),
@@ -188,10 +189,11 @@ def extra_check(match: Match) -> bool:
188189
189190
190191def register_ar_residual_norm_fp8_quant (custom_pass : PatternMatcherPass ,
191- mapping : Mapping ):
192+ mapping : Mapping ,
193+ allreduce_func : Callable ):
192194 input_node = KeywordArg ("input" )
193195 strategy_node = KeywordArg ("strategy" )
194- allreduce_default = CallFunction (torch . ops . trtllm . allreduce .default ,
196+ allreduce_default = CallFunction (allreduce_func .default ,
195197 input_node ,
196198 KeywordArg ("residual" ),
197199 KeywordArg ("gamma" ),
@@ -242,7 +244,7 @@ def target_pattern(
242244 scale : torch .Tensor ,
243245 trigger_completion_at_end : bool ,
244246 ):
245- allreduce = torch . ops . trtllm . allreduce (
247+ allreduce = allreduce_func (
246248 input , residual , gamma , scale , None , workspace , mapping .tp_group ,
247249 int (strategy ), int (AllReduceFusionOp .RESIDUAL_RMS_NORM_QUANT_FP8 ),
248250 float (eps ), trigger_completion_at_end )
@@ -264,10 +266,11 @@ def extra_check(match: Match) -> bool:
264266
265267
266268def register_ar_residual_norm_out_fp4_quant (custom_pass : PatternMatcherPass ,
267- mapping : Mapping ):
269+ mapping : Mapping ,
270+ allreduce_func : Callable ):
268271 input_node = KeywordArg ("input" )
269272 strategy_node = KeywordArg ("strategy" )
270- allreduce_default = CallFunction (torch . ops . trtllm . allreduce .default ,
273+ allreduce_default = CallFunction (allreduce_func .default ,
271274 input_node ,
272275 KeywordArg ("residual" ),
273276 KeywordArg ("gamma" ),
@@ -313,7 +316,7 @@ def target_pattern(
313316 scale : torch .Tensor ,
314317 trigger_completion_at_end : bool ,
315318 ):
316- allreduce = torch . ops . trtllm . allreduce (
319+ allreduce = allreduce_func (
317320 input , residual , gamma , scale , None , workspace , mapping .tp_group ,
318321 int (strategy ),
319322 int (AllReduceFusionOp .RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4 ),
@@ -336,10 +339,11 @@ def extra_check(match: Match) -> bool:
336339
337340
338341def register_ar_residual_norm_fp4_quant (custom_pass : PatternMatcherPass ,
339- mapping : Mapping ):
342+ mapping : Mapping ,
343+ allreduce_func : Callable ):
340344 input_node = KeywordArg ("input" )
341345 strategy_node = KeywordArg ("strategy" )
342- allreduce_default = CallFunction (torch . ops . trtllm . allreduce .default ,
346+ allreduce_default = CallFunction (allreduce_func .default ,
343347 input_node ,
344348 KeywordArg ("residual" ),
345349 KeywordArg ("gamma" ),
@@ -385,7 +389,7 @@ def target_pattern(
385389 scale : torch .Tensor ,
386390 trigger_completion_at_end : bool ,
387391 ):
388- allreduce = torch . ops . trtllm . allreduce (
392+ allreduce = allreduce_func (
389393 input , residual , gamma , scale , None , workspace , mapping .tp_group ,
390394 int (strategy ), int (AllReduceFusionOp .RESIDUAL_RMS_NORM_QUANT_NVFP4 ),
391395 float (eps ), trigger_completion_at_end )
@@ -407,17 +411,20 @@ def extra_check(match: Match) -> bool:
407411
408412
409413def register_ub_patterns (custom_passes : List [PatternMatcherPass ],
410- mapping : Mapping ):
414+ mapping : Mapping , allreduce_func : Callable ):
411415
412416 def register_convert_supported_ar_to_ub (custom_pass : PatternMatcherPass ):
413417 strategy = int (AllReduceStrategy .AUTO )
414418 input_node = KeywordArg ('input' )
415419 fusion = KeywordArg ('fusion_op' )
416- trtllm_allreduce_default = CallFunction (
417- torch .ops .trtllm .allreduce .default , input_node ,
418- KeywordArg ('residual_in' ), KeywordArg ('gamma' ), KeywordArg ('scale' ),
419- None , Ignored (), mapping .tp_group , strategy , fusion ,
420- KeywordArg ('eps' ), Ignored ())
420+ trtllm_allreduce_default = CallFunction (allreduce_func .default ,
421+ input_node ,
422+ KeywordArg ('residual_in' ),
423+ KeywordArg ('gamma' ),
424+ KeywordArg ('scale' ), None ,
425+ Ignored (), mapping .tp_group ,
426+ strategy , fusion ,
427+ KeywordArg ('eps' ), Ignored ())
421428
422429 def empty_convert_supported_ar_to_ub (
423430 input : torch .Tensor ,
@@ -667,7 +674,7 @@ def register_ub_finalize_patterns(custom_pass: PatternMatcherPass):
667674 torch .ops .trtllm .userbuffers_allreduce_finalize .default ,
668675 KeywordArg ("sharded_residual" ), False )
669676 trtllm_allreduce_default = CallFunction (
670- torch .ops .trtllm .allreduce . default , KeywordArg ("input" ),
677+ torch .ops .trtllm .allreduce , KeywordArg ("input" ),
671678 trtllm_userbuffers_allreduce_finalize_default , KeywordArg ("gamma" ),
672679 KeywordArg ("scale" ), Ignored (), Ignored (), mapping .tp_group ,
673680 int (AllReduceStrategy .UB ), KeywordArg ("fusion_op" ),
@@ -718,15 +725,28 @@ def target_finalize_pattern(
718725
def register_ar_fusions(custom_passes: List[PatternMatcherPass],
                        mapping: Mapping, enable_ub: bool):
    """Register the allreduce fusion patterns for both allreduce op variants.

    Every registration helper is invoked once for
    ``torch.ops.trtllm.allreduce`` and once for
    ``torch.ops.trtllm.tunable_allreduce``, in that order.

    Args:
        custom_passes: List of pattern-matcher passes. Its current last entry
            receives the residual-norm patterns; a fresh
            ``PatternMatcherPass`` is then appended to it and receives the
            quantization patterns. The list itself is handed to
            ``register_ub_patterns`` when ``enable_ub`` is true.
        mapping: Forwarded unchanged to every registration helper.
        enable_ub: When true, the "out quant" patterns are skipped and
            ``register_ub_patterns`` is called instead.
    """
    allreduce_ops = (torch.ops.trtllm.allreduce,
                     torch.ops.trtllm.tunable_allreduce)

    # Residual-norm fusions go into whichever pass is last on entry.
    residual_norm_pass = custom_passes[-1]
    for op in allreduce_ops:
        register_ar_residual_norm(residual_norm_pass, mapping, op)

    # Quantization fusions get their own, newly appended pass.
    quant_pass = PatternMatcherPass()
    custom_passes.append(quant_pass)
    for op in allreduce_ops:
        register_ar_residual_norm_fp8_quant(quant_pass, mapping, op)
        register_ar_residual_norm_fp4_quant(quant_pass, mapping, op)

        if enable_ub:
            # AR-Residual-Norm-Out-Quant-X is not supported by Userbuffers kernel.
            continue
        register_ar_residual_norm_out_fp8_quant(quant_pass, mapping, op)
        register_ar_residual_norm_out_fp4_quant(quant_pass, mapping, op)

    if not enable_ub:
        return
    for op in allreduce_ops:
        register_ub_patterns(custom_passes, mapping, op)
0 commit comments