Skip to content

Commit 46d6ddd

Browse files
committed
asdf
Signed-off-by: Ludwig Schneider <lschneider@nvidia.com>
1 parent 4c0ef01 commit 46d6ddd

File tree

1 file changed

+23
-6
lines changed

1 file changed

+23
-6
lines changed

tensorrt_llm/_torch/compilation/patterns/ar_residual_norm.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -722,18 +722,20 @@ def register_nccl_symmetric_patterns(custom_passes: List[PatternMatcherPass],
722722

723723
def register_convert_supported_ar_to_nccl_symmetric(
724724
custom_pass: PatternMatcherPass):
725-
strategy = int(AllReduceStrategy.NCCL_SYMMETRIC)
725+
# Use KeywordArg for all arguments since the actual call uses keyword arguments
726726
input_node = KeywordArg('input')
727-
# Use actual schema parameter names for matching: 'residual', 'norm_weight', 'op', 'group'
727+
# Use actual schema parameter names for matching: 'residual', 'norm_weight', 'op', 'group', 'strategy'
728728
# But keep 'residual_in', 'gamma', 'fusion_op' in function signatures for clarity
729729
fusion = KeywordArg('op') # Schema uses 'op', not 'fusion_op'
730730
group_arg = KeywordArg(
731731
'group') # Use KeywordArg to match keyword arguments
732+
strategy_arg = KeywordArg(
733+
'strategy') # Use KeywordArg to match keyword arguments
732734
trtllm_allreduce_default = CallFunction(
733735
torch.ops.trtllm.allreduce.default, input_node,
734736
KeywordArg('residual'), KeywordArg('norm_weight'),
735-
KeywordArg('scale'), None, Ignored(), group_arg, strategy, fusion,
736-
KeywordArg('eps'), Ignored())
737+
KeywordArg('scale'), None, Ignored(), group_arg, strategy_arg,
738+
fusion, KeywordArg('eps'), Ignored())
737739

738740
def empty_convert_supported_ar_to_nccl_symmetric(
739741
input: torch.Tensor,
@@ -774,6 +776,21 @@ def extra_check_convert_supported_ar_to_nccl_symmetric(
774776
f"pattern_to_node keys: {list(match.ctx.pattern_to_node.keys())}"
775777
)
776778

779+
# Verify strategy is NCCL_SYMMETRIC
780+
try:
781+
strategy_value = match.ctx.pattern_to_node[strategy_arg]
782+
if strategy_value != int(AllReduceStrategy.NCCL_SYMMETRIC):
783+
logger.debug(
784+
f"[NCCL_SYMMETRIC] Pattern: extra_check failed: strategy={strategy_value} "
785+
f"is not NCCL_SYMMETRIC ({int(AllReduceStrategy.NCCL_SYMMETRIC)})"
786+
)
787+
return False
788+
except KeyError as e:
789+
logger.debug(
790+
f"[NCCL_SYMMETRIC] Pattern: extra_check failed: strategy not found: {e}"
791+
)
792+
return False
793+
777794
# Verify group matches mapping.tp_group
778795
try:
779796
group_value = match.ctx.pattern_to_node[group_arg]
@@ -817,8 +834,8 @@ def extra_check_convert_supported_ar_to_nccl_symmetric(
817834
return False
818835

819836
logger.debug(
820-
f"[NCCL_SYMMETRIC] Pattern: extra_check passed: fusion_value={fusion_value}, group={group_value}"
821-
)
837+
f"[NCCL_SYMMETRIC] Pattern: extra_check passed: fusion_value={fusion_value}, "
838+
f"strategy={strategy_value}, group={group_value}")
822839
return True
823840

824841
logger.debug(

0 commit comments

Comments
 (0)