Skip to content

Commit f20fe4a

Browse files
committed
clever switching of graph applications
Signed-off-by: Ludwig Schneider <lschneider@nvidia.com>
1 parent d4249ed commit f20fe4a

File tree

2 files changed

+37
-6
lines changed

2 files changed

+37
-6
lines changed

tensorrt_llm/_torch/compilation/pattern_applier.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,15 @@
88

99
import torch
1010
import torch.nn as nn
11+
from torch._inductor.pattern_matcher import PatternMatcherPass
1112
from torch.fx import GraphModule
1213

1314
import tensorrt_llm
1415
from tensorrt_llm import logger
1516
from tensorrt_llm.mapping import Mapping
1617

1718
from .backend import Backend
19+
from .patterns.ar_residual_norm import register_nccl_symmetric_patterns
1820

1921

2022
def apply_nccl_symmetric_patterns_to_model(model: nn.Module, mapping: Mapping) -> nn.Module:
@@ -51,6 +53,10 @@ def __init__(self, mapping: Mapping):
5153
max_num_streams=1,
5254
mapping=mapping,
5355
)
56+
# Override custom_passes to register NCCL_SYMMETRIC patterns with match_auto_strategy=True
57+
# This ensures we match the AUTO strategy nodes (which are what the graph actually contains) and convert them to NCCL_SYMMETRIC.
58+
self.custom_passes = [PatternMatcherPass()]
59+
register_nccl_symmetric_patterns(self.custom_passes, mapping, match_auto_strategy=True)
5460

5561
def __call__(self, gm: GraphModule, example_inputs: List[torch.Tensor]) -> callable:
5662
"""

tensorrt_llm/_torch/compilation/patterns/ar_residual_norm.py

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -748,14 +748,29 @@ def target_finalize_pattern(
748748

749749

750750
def register_nccl_symmetric_patterns(custom_passes: List[PatternMatcherPass],
751-
mapping: Mapping):
751+
mapping: Mapping,
752+
match_auto_strategy: bool = False):
753+
"""
754+
Register NCCL_SYMMETRIC patterns.
755+
756+
Args:
757+
custom_passes: List of pattern matcher passes to register patterns in
758+
mapping: Mapping configuration for TP group
759+
match_auto_strategy: If True, match AUTO strategy nodes (strategy=3) and convert to NCCL_SYMMETRIC.
760+
If False, match NCCL_SYMMETRIC strategy nodes (strategy=8) directly.
761+
Should be True when NCCL_SYMMETRIC is requested but the graph nodes carry the AUTO strategy.
762+
"""
752763

753764
def register_convert_supported_ar_to_nccl_symmetric(
754765
custom_pass: PatternMatcherPass):
755766
# FX normalizes keyword arguments to positional arguments based on schema
756767
# Pattern should match positional arguments in schema order:
757768
# input, residual, norm_weight, scale, bias, workspace, group, strategy, op, eps, trigger_completion_at_end
758-
strategy = int(AllReduceStrategy.NCCL_SYMMETRIC)
769+
# When match_auto_strategy=True, match AUTO strategy nodes (strategy=3) since graph nodes
770+
# typically have AUTO even when NCCL_SYMMETRIC is requested in config. The pattern will convert them to NCCL_SYMMETRIC.
771+
# When match_auto_strategy=False, match NCCL_SYMMETRIC strategy nodes (strategy=8) directly.
772+
strategy = int(AllReduceStrategy.AUTO) if match_auto_strategy else int(
773+
AllReduceStrategy.NCCL_SYMMETRIC)
759774
input_node = KeywordArg('input')
760775
fusion = KeywordArg('op') # Schema uses 'op', not 'fusion_op'
761776
trtllm_allreduce_default = CallFunction(
@@ -765,8 +780,15 @@ def register_convert_supported_ar_to_nccl_symmetric(
765780
fusion, KeywordArg('eps'), Ignored())
766781

767782
if tensorrt_llm.mpi_rank() == 0:
768-
logger.debug("[NCCL_SYMMETRIC] Pattern: Registering pattern with "
769-
f"strategy={strategy}, group={mapping.tp_group}")
783+
if match_auto_strategy:
784+
logger.debug(
785+
"[NCCL_SYMMETRIC] Pattern: Registering pattern to match AUTO strategy "
786+
f"(strategy={strategy}) and convert to NCCL_SYMMETRIC, group={mapping.tp_group}"
787+
)
788+
else:
789+
logger.debug(
790+
"[NCCL_SYMMETRIC] Pattern: Registering pattern to match NCCL_SYMMETRIC strategy "
791+
f"(strategy={strategy}), group={mapping.tp_group}")
770792

771793
def empty_convert_supported_ar_to_nccl_symmetric(
772794
input: torch.Tensor,
@@ -1085,5 +1107,8 @@ def register_ar_fusions(custom_passes: List[PatternMatcherPass],
10851107
if enable_ub:
10861108
register_ub_patterns(custom_passes, mapping)
10871109

1088-
# Always register NCCL_SYMMETRIC patterns (they only match when strategy is explicitly NCCL_SYMMETRIC)
1089-
register_nccl_symmetric_patterns(custom_passes, mapping)
1110+
# Register NCCL_SYMMETRIC patterns - by default match NCCL_SYMMETRIC strategy nodes directly
1111+
# Set match_auto_strategy=True only when NCCL_SYMMETRIC is requested but the graph has AUTO nodes
1112+
register_nccl_symmetric_patterns(custom_passes,
1113+
mapping,
1114+
match_auto_strategy=False)

0 commit comments

Comments
 (0)