@@ -722,18 +722,20 @@ def register_nccl_symmetric_patterns(custom_passes: List[PatternMatcherPass],
722722
723723 def register_convert_supported_ar_to_nccl_symmetric (
724724 custom_pass : PatternMatcherPass ):
725- strategy = int ( AllReduceStrategy . NCCL_SYMMETRIC )
725+ # Use KeywordArg for all arguments since the actual call uses keyword arguments
726726 input_node = KeywordArg ('input' )
727- # Use actual schema parameter names for matching: 'residual', 'norm_weight', 'op', 'group'
727+ # Use actual schema parameter names for matching: 'residual', 'norm_weight', 'op', 'group', 'strategy'
728728 # But keep 'residual_in', 'gamma', 'fusion_op' in function signatures for clarity
729729 fusion = KeywordArg ('op' ) # Schema uses 'op', not 'fusion_op'
730730 group_arg = KeywordArg (
731731 'group' ) # Use KeywordArg to match keyword arguments
732+ strategy_arg = KeywordArg (
733+ 'strategy' ) # Use KeywordArg to match keyword arguments
732734 trtllm_allreduce_default = CallFunction (
733735 torch .ops .trtllm .allreduce .default , input_node ,
734736 KeywordArg ('residual' ), KeywordArg ('norm_weight' ),
735- KeywordArg ('scale' ), None , Ignored (), group_arg , strategy , fusion ,
736- KeywordArg ('eps' ), Ignored ())
737+ KeywordArg ('scale' ), None , Ignored (), group_arg , strategy_arg ,
738+ fusion , KeywordArg ('eps' ), Ignored ())
737739
738740 def empty_convert_supported_ar_to_nccl_symmetric (
739741 input : torch .Tensor ,
@@ -774,6 +776,21 @@ def extra_check_convert_supported_ar_to_nccl_symmetric(
774776 f"pattern_to_node keys: { list (match .ctx .pattern_to_node .keys ())} "
775777 )
776778
779+ # Verify strategy is NCCL_SYMMETRIC
780+ try :
781+ strategy_value = match .ctx .pattern_to_node [strategy_arg ]
782+ if strategy_value != int (AllReduceStrategy .NCCL_SYMMETRIC ):
783+ logger .debug (
784+ f"[NCCL_SYMMETRIC] Pattern: extra_check failed: strategy={ strategy_value } "
785+ f"is not NCCL_SYMMETRIC ({ int (AllReduceStrategy .NCCL_SYMMETRIC )} )"
786+ )
787+ return False
788+ except KeyError as e :
789+ logger .debug (
790+ f"[NCCL_SYMMETRIC] Pattern: extra_check failed: strategy not found: { e } "
791+ )
792+ return False
793+
777794 # Verify group matches mapping.tp_group
778795 try :
779796 group_value = match .ctx .pattern_to_node [group_arg ]
@@ -817,8 +834,8 @@ def extra_check_convert_supported_ar_to_nccl_symmetric(
817834 return False
818835
819836 logger .debug (
820- f"[NCCL_SYMMETRIC] Pattern: extra_check passed: fusion_value={ fusion_value } , group= { group_value } "
821- )
837+ f"[NCCL_SYMMETRIC] Pattern: extra_check passed: fusion_value={ fusion_value } , "
838+ f"strategy= { strategy_value } , group= { group_value } " )
822839 return True
823840
824841 logger .debug (
0 commit comments