1010import torch .nn as nn
1111from torch .fx import GraphModule
1212
13+ import tensorrt_llm
1314from tensorrt_llm import logger
1415from tensorrt_llm .mapping import Mapping
1516
@@ -58,10 +59,11 @@ def __call__(self, gm: GraphModule, example_inputs: List[torch.Tensor]) -> calla
5859 This reuses the exact same pattern application code that Backend uses,
5960 ensuring consistency with the rest of TRT-LLM.
6061 """
61- # Debug: Print graph structure to understand why patterns aren't matching
62- logger .debug (
63- f"[NCCL_SYMMETRIC] PatternOnlyBackend: Graph has { len (list (gm .graph .nodes ))} nodes"
64- )
62+ # Debug: Print graph structure to understand why patterns aren't matching (rank 0 only)
63+ if tensorrt_llm .mpi_rank () == 0 :
64+ logger .debug (
65+ f"[NCCL_SYMMETRIC] PatternOnlyBackend: Graph has { len (list (gm .graph .nodes ))} nodes"
66+ )
6567 # Log all allreduce calls in the graph
6668 allreduce_nodes = []
6769 for n in gm .graph .nodes :
@@ -70,9 +72,10 @@ def __call__(self, gm: GraphModule, example_inputs: List[torch.Tensor]) -> calla
7072 if "allreduce" in target_str .lower ():
7173 allreduce_nodes .append (n )
7274
73- logger .debug (
74- f"[NCCL_SYMMETRIC] PatternOnlyBackend: Found { len (allreduce_nodes )} allreduce nodes in graph"
75- )
75+ if tensorrt_llm .mpi_rank () == 0 :
76+ logger .debug (
77+ f"[NCCL_SYMMETRIC] PatternOnlyBackend: Found { len (allreduce_nodes )} allreduce nodes in graph"
78+ )
7679 for i , node in enumerate (allreduce_nodes ):
7780 # Extract strategy from args (it's typically the 7th positional arg)
7881 strategy_val = None
@@ -102,26 +105,28 @@ def __call__(self, gm: GraphModule, example_inputs: List[torch.Tensor]) -> calla
102105 except Exception :
103106 pass
104107
105- logger .debug (
106- f"[NCCL_SYMMETRIC] PatternOnlyBackend: AllReduce node { i } : "
107- f"target={ node .target } , args={ len (node .args )} , "
108- f"strategy={ strategy_val } (type={ type (strategy_val )} ), "
109- f"fusion_op={ fusion_val } (type={ type (fusion_val )} ), "
110- f"kwargs={ node .kwargs } , { schema_info } "
111- )
112- # Log first few args to understand structure
113- if len (node .args ) > 0 :
114- args_slice = node .args [0 :5 ] if len (node .args ) >= 5 else node .args
115- logger .debug (
116- f"[NCCL_SYMMETRIC] PatternOnlyBackend: AllReduce node { i } "
117- f"args[0:5]={ args_slice } "
118- )
119- elif node .kwargs :
120- kwarg_keys = list (node .kwargs .keys ())
108+ if tensorrt_llm .mpi_rank () == 0 :
121109 logger .debug (
122- f"[NCCL_SYMMETRIC] PatternOnlyBackend: AllReduce node { i } "
123- f"using keyword args: { kwarg_keys } "
110+ f"[NCCL_SYMMETRIC] PatternOnlyBackend: AllReduce node { i } : "
111+ f"target={ node .target } , args={ len (node .args )} , "
112+ f"strategy={ strategy_val } (type={ type (strategy_val )} ), "
113+ f"fusion_op={ fusion_val } (type={ type (fusion_val )} ), "
114+ f"kwargs={ node .kwargs } , { schema_info } "
124115 )
116+ # Log first few args to understand structure (rank 0 only)
117+ if tensorrt_llm .mpi_rank () == 0 :
118+ if len (node .args ) > 0 :
119+ args_slice = node .args [0 :5 ] if len (node .args ) >= 5 else node .args
120+ logger .debug (
121+ f"[NCCL_SYMMETRIC] PatternOnlyBackend: AllReduce node { i } "
122+ f"args[0:5]={ args_slice } "
123+ )
124+ elif node .kwargs :
125+ kwarg_keys = list (node .kwargs .keys ())
126+ logger .debug (
127+ f"[NCCL_SYMMETRIC] PatternOnlyBackend: AllReduce node { i } "
128+ f"using keyword args: { kwarg_keys } "
129+ )
125130
126131 # Use the existing optimize() method which already handles:
127132 # - recover_pass()
0 commit comments