@@ -748,14 +748,29 @@ def target_finalize_pattern(
748748
749749
750750def register_nccl_symmetric_patterns (custom_passes : List [PatternMatcherPass ],
751- mapping : Mapping ):
751+ mapping : Mapping ,
752+ match_auto_strategy : bool = False ):
753+ """
754+ Register NCCL_SYMMETRIC patterns.
755+
756+ Args:
757+ custom_passes: List of pattern matcher passes to register patterns in
758+ mapping: Mapping configuration for TP group
759+ match_auto_strategy: If True, match AUTO strategy nodes (strategy=3) and convert to NCCL_SYMMETRIC.
760+ If False, match NCCL_SYMMETRIC strategy nodes (strategy=8) directly.
761+ Should be True when NCCL_SYMMETRIC is requested but graph nodes have AUTO.
762+ """
752763
753764 def register_convert_supported_ar_to_nccl_symmetric (
754765 custom_pass : PatternMatcherPass ):
755766 # FX normalizes keyword arguments to positional arguments based on schema
756767 # Pattern should match positional arguments in schema order:
757768 # input, residual, norm_weight, scale, bias, workspace, group, strategy, op, eps, trigger_completion_at_end
758- strategy = int (AllReduceStrategy .NCCL_SYMMETRIC )
769+ # When match_auto_strategy=True, match AUTO strategy nodes (strategy=3) since graph nodes
770+ # typically have AUTO even when NCCL_SYMMETRIC is requested in config. The pattern will convert them to NCCL_SYMMETRIC.
771+ # When match_auto_strategy=False, match NCCL_SYMMETRIC strategy nodes (strategy=8) directly.
772+ strategy = int (AllReduceStrategy .AUTO ) if match_auto_strategy else int (
773+ AllReduceStrategy .NCCL_SYMMETRIC )
759774 input_node = KeywordArg ('input' )
760775 fusion = KeywordArg ('op' ) # Schema uses 'op', not 'fusion_op'
761776 trtllm_allreduce_default = CallFunction (
@@ -765,8 +780,15 @@ def register_convert_supported_ar_to_nccl_symmetric(
765780 fusion , KeywordArg ('eps' ), Ignored ())
766781
767782 if tensorrt_llm .mpi_rank () == 0 :
768- logger .debug ("[NCCL_SYMMETRIC] Pattern: Registering pattern with "
769- f"strategy={ strategy } , group={ mapping .tp_group } " )
783+ if match_auto_strategy :
784+ logger .debug (
785+ "[NCCL_SYMMETRIC] Pattern: Registering pattern to match AUTO strategy "
786+ f"(strategy={ strategy } ) and convert to NCCL_SYMMETRIC, group={ mapping .tp_group } "
787+ )
788+ else :
789+ logger .debug (
790+ "[NCCL_SYMMETRIC] Pattern: Registering pattern to match NCCL_SYMMETRIC strategy "
791+ f"(strategy={ strategy } ), group={ mapping .tp_group } " )
770792
771793 def empty_convert_supported_ar_to_nccl_symmetric (
772794 input : torch .Tensor ,
@@ -1085,5 +1107,8 @@ def register_ar_fusions(custom_passes: List[PatternMatcherPass],
10851107 if enable_ub :
10861108 register_ub_patterns (custom_passes , mapping )
10871109
1088- # Always register NCCL_SYMMETRIC patterns (they only match when strategy is explicitly NCCL_SYMMETRIC)
1089- register_nccl_symmetric_patterns (custom_passes , mapping )
1110+ # Register NCCL_SYMMETRIC patterns - by default match NCCL_SYMMETRIC strategy nodes directly
1111+ # Set match_auto_strategy=True only when NCCL_SYMMETRIC is requested but graph has AUTO nodes
1112+ register_nccl_symmetric_patterns (custom_passes ,
1113+ mapping ,
1114+ match_auto_strategy = False )
0 commit comments