Skip to content

Commit d4249ed

Browse files
committed
fresh try
Signed-off-by: Ludwig Schneider <lschneider@nvidia.com>
1 parent 46d6ddd commit d4249ed

File tree

4 files changed

+138
-118
lines changed

tensorrt_llm/_torch/compilation/backend.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -100,21 +100,24 @@ def optimize(
100100
example_inputs: List[torch.Tensor],
101101
):
102102
graph = gm.graph
103-
logger.debug(
104-
f"[NCCL_SYMMETRIC] Pattern: optimize() called with {len(self.custom_passes)} custom passes"
105-
)
103+
if self.rank == 0:
104+
logger.debug(
105+
f"[NCCL_SYMMETRIC] Pattern: optimize() called with {len(self.custom_passes)} custom passes"
106+
)
106107
for i, custom_pass in enumerate(self.custom_passes):
107108
match_count = custom_pass.apply(graph)
108109
self.match_count.append(match_count)
109-
logger.debug(
110-
f"[NCCL_SYMMETRIC] Pattern: Pass {i} applied, matched {match_count} patterns"
111-
)
110+
if self.rank == 0:
111+
logger.debug(
112+
f"[NCCL_SYMMETRIC] Pattern: Pass {i} applied, matched {match_count} patterns"
113+
)
112114
while self.match_count[-1]:
113115
match_count = custom_pass.apply(graph)
114116
self.match_count.append(match_count)
115-
logger.debug(
116-
f"[NCCL_SYMMETRIC] Pattern: Pass {i} re-applied, matched {match_count} more patterns"
117-
)
117+
if self.rank == 0:
118+
logger.debug(
119+
f"[NCCL_SYMMETRIC] Pattern: Pass {i} re-applied, matched {match_count} more patterns"
120+
)
118121
graph.eliminate_dead_code()
119122
# After this pass, cannot run any dce!!!
120123
remove_copy_for_mutates_args(graph)

tensorrt_llm/_torch/compilation/pattern_applier.py

Lines changed: 30 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import torch.nn as nn
1111
from torch.fx import GraphModule
1212

13+
import tensorrt_llm
1314
from tensorrt_llm import logger
1415
from tensorrt_llm.mapping import Mapping
1516

@@ -58,10 +59,11 @@ def __call__(self, gm: GraphModule, example_inputs: List[torch.Tensor]) -> calla
5859
This reuses the exact same pattern application code that Backend uses,
5960
ensuring consistency with the rest of TRT-LLM.
6061
"""
61-
# Debug: Print graph structure to understand why patterns aren't matching
62-
logger.debug(
63-
f"[NCCL_SYMMETRIC] PatternOnlyBackend: Graph has {len(list(gm.graph.nodes))} nodes"
64-
)
62+
# Debug: Print graph structure to understand why patterns aren't matching (rank 0 only)
63+
if tensorrt_llm.mpi_rank() == 0:
64+
logger.debug(
65+
f"[NCCL_SYMMETRIC] PatternOnlyBackend: Graph has {len(list(gm.graph.nodes))} nodes"
66+
)
6567
# Log all allreduce calls in the graph
6668
allreduce_nodes = []
6769
for n in gm.graph.nodes:
@@ -70,9 +72,10 @@ def __call__(self, gm: GraphModule, example_inputs: List[torch.Tensor]) -> calla
7072
if "allreduce" in target_str.lower():
7173
allreduce_nodes.append(n)
7274

73-
logger.debug(
74-
f"[NCCL_SYMMETRIC] PatternOnlyBackend: Found {len(allreduce_nodes)} allreduce nodes in graph"
75-
)
75+
if tensorrt_llm.mpi_rank() == 0:
76+
logger.debug(
77+
f"[NCCL_SYMMETRIC] PatternOnlyBackend: Found {len(allreduce_nodes)} allreduce nodes in graph"
78+
)
7679
for i, node in enumerate(allreduce_nodes):
7780
# Extract strategy from args (it's typically the 7th positional arg)
7881
strategy_val = None
@@ -102,26 +105,28 @@ def __call__(self, gm: GraphModule, example_inputs: List[torch.Tensor]) -> calla
102105
except Exception:
103106
pass
104107

105-
logger.debug(
106-
f"[NCCL_SYMMETRIC] PatternOnlyBackend: AllReduce node {i}: "
107-
f"target={node.target}, args={len(node.args)}, "
108-
f"strategy={strategy_val} (type={type(strategy_val)}), "
109-
f"fusion_op={fusion_val} (type={type(fusion_val)}), "
110-
f"kwargs={node.kwargs}, {schema_info}"
111-
)
112-
# Log first few args to understand structure
113-
if len(node.args) > 0:
114-
args_slice = node.args[0:5] if len(node.args) >= 5 else node.args
115-
logger.debug(
116-
f"[NCCL_SYMMETRIC] PatternOnlyBackend: AllReduce node {i} "
117-
f"args[0:5]={args_slice}"
118-
)
119-
elif node.kwargs:
120-
kwarg_keys = list(node.kwargs.keys())
108+
if tensorrt_llm.mpi_rank() == 0:
121109
logger.debug(
122-
f"[NCCL_SYMMETRIC] PatternOnlyBackend: AllReduce node {i} "
123-
f"using keyword args: {kwarg_keys}"
110+
f"[NCCL_SYMMETRIC] PatternOnlyBackend: AllReduce node {i}: "
111+
f"target={node.target}, args={len(node.args)}, "
112+
f"strategy={strategy_val} (type={type(strategy_val)}), "
113+
f"fusion_op={fusion_val} (type={type(fusion_val)}), "
114+
f"kwargs={node.kwargs}, {schema_info}"
124115
)
116+
# Log first few args to understand structure (rank 0 only)
117+
if tensorrt_llm.mpi_rank() == 0:
118+
if len(node.args) > 0:
119+
args_slice = node.args[0:5] if len(node.args) >= 5 else node.args
120+
logger.debug(
121+
f"[NCCL_SYMMETRIC] PatternOnlyBackend: AllReduce node {i} "
122+
f"args[0:5]={args_slice}"
123+
)
124+
elif node.kwargs:
125+
kwarg_keys = list(node.kwargs.keys())
126+
logger.debug(
127+
f"[NCCL_SYMMETRIC] PatternOnlyBackend: AllReduce node {i} "
128+
f"using keyword args: {kwarg_keys}"
129+
)
125130

126131
# Use the existing optimize() method which already handles:
127132
# - recover_pass()

0 commit comments

Comments
 (0)