Skip to content

Commit 38cf331

Browse files
[Flexible Outputs] Allow multi-output matching
Adds support for pattern matching on multiple outputs for optimizations involving ops like Split. For example:

    def pattern(self, op, x):
        relu = op.Relu(x)
        return op.Split(relu, _allow_flexible_outputs=True)

Fixes #2581
1 parent 1077da7 commit 38cf331

File tree

4 files changed

+294
-13
lines changed

4 files changed

+294
-13
lines changed

onnxscript/rewriter/_matcher.py

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -310,9 +310,12 @@ def _match_single_output_node(
310310
if output_values is None:
311311
# TODO(rama): Is this a valid (useful) case?
312312
return match
313-
if check_removable and not _valid_to_replace(match.nodes, output_values):
314-
# TODO(rama): Match status should be updated to reflect failure reason.
315-
return match.fail("Matched nodes have other uses preventing replacement.")
313+
# Skip removability check for flexible output nodes since they may have
314+
# additional outputs beyond those captured in the pattern
315+
if check_removable and not pattern.output_node.allow_flexible_outputs:
316+
if not _valid_to_replace(match.nodes, output_values):
317+
# TODO(rama): Match status should be updated to reflect failure reason.
318+
return match.fail("Matched nodes have other uses preventing replacement.")
316319

317320
match.outputs.extend(output_values)
318321
return match

onnxscript/rewriter/_pattern_ir.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -248,6 +248,7 @@ def __call__(
248248
_outputs: int | list[str | None] = 1,
249249
_allow_other_attributes: bool | None = None,
250250
_allow_other_inputs: bool | None = None,
251+
_allow_flexible_outputs: bool | None = None,
251252
_check: Callable | None = None,
252253
**kwargs,
253254
):
@@ -280,6 +281,7 @@ def __call__(
280281
_outputs,
281282
allow_other_attributes=_allow_other_attributes,
282283
allow_other_inputs=_allow_other_inputs,
284+
allow_flexible_outputs=_allow_flexible_outputs,
283285
check=_check,
284286
)
285287
self.pattern_builder.add_node(node_pattern)
@@ -440,6 +442,7 @@ def __init__(
440442
*,
441443
allow_other_attributes: bool | None,
442444
allow_other_inputs: bool | None,
445+
allow_flexible_outputs: bool | None = None,
443446
check: Callable | None = None,
444447
):
445448
if allow_other_attributes is None:
@@ -448,12 +451,16 @@ def __init__(
448451
if allow_other_inputs is None:
449452
# TODO(rama): Should we default to True? For now, we preserve the current behavior.
450453
allow_other_inputs = False
454+
if allow_flexible_outputs is None:
455+
# Default behavior: do not match flexible outputs
456+
allow_flexible_outputs = False
451457
self.domain = domain
452458
self.op = StringConstantPattern(op) if isinstance(op, str) else op
453459
self.inputs = [_to_value_pattern(x) for x in inputs]
454460
self.attributes = attributes
455461
self.allow_other_attributes = allow_other_attributes
456462
self.allow_other_inputs = allow_other_inputs
463+
self.allow_flexible_outputs = allow_flexible_outputs
457464
self._check = check
458465
# In the common case, domain and op are constants, which can be used to optimize matching.
459466
if isinstance(op, str) and isinstance(domain, StringConstantPattern):

onnxscript/rewriter/_rewrite_rule.py

Lines changed: 39 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@
66

77
import abc
88
import dataclasses
9+
import inspect
910
import itertools
1011
from typing import (
1112
Callable,
@@ -214,11 +215,14 @@ class ReplacementPatternFunction:
214215

215216
def __init__(self, function) -> None:
216217
self._function = function
218+
# Cache signature inspection to avoid repeated introspection on hot path
219+
self._accepts_match = "_match" in inspect.signature(function).parameters
217220

218221
def get_replacement(self, match: _basics.MatchResult) -> ReplacementSubgraph | None:
219222
context = RewriterContext()
223+
bindings = match.bindings if not self._accepts_match else {**match.bindings, "_match": match}
220224
try:
221-
new_outputs = self._function(context, **match.bindings)
225+
new_outputs = self._function(context, **bindings)
222226
except _basics.MatchFailureError as e:
223227
match.fail(e.reason, list(e.failure_sources))
224228
return None
@@ -313,6 +317,11 @@ def __init__(
313317
# Initialize the base pattern matching functionality
314318
super().__init__(target_pattern, condition_function, matcher, verbose, name)
315319

320+
# Check if any node in the pattern uses flexible outputs (cache for hot path)
321+
self._has_flexible_outputs = any(
322+
node.allow_flexible_outputs for node in self._target_pattern._nodes
323+
)
324+
316325
if not isinstance(replacement_pattern, ReplacementPatternFunction):
317326
replacement_pattern = ReplacementPatternFunction(replacement_pattern)
318327
self._replacement_pattern = replacement_pattern
@@ -357,7 +366,8 @@ def try_rewrite(
357366
_basics.MatchStatus.REPLACEMENT_FAILED,
358367
)
359368
return None
360-
if len(replacement_subgraph.new_outputs) != self._target_pattern.num_outputs:
369+
370+
if not self._has_flexible_outputs and len(replacement_subgraph.new_outputs) != self._target_pattern.num_outputs:
361371
raise ValueError(
362372
f"Number of outputs from replacement function does not match the number of outputs from the target pattern. "
363373
f"Expected {self._target_pattern.num_outputs}, but got {len(replacement_subgraph.new_outputs)}."
@@ -766,14 +776,33 @@ def _apply_to_graph_or_function(
766776
for n in delta.new_nodes:
767777
n.metadata_props[RULE_NAME_TAG] = rule.name
768778

769-
convenience.replace_nodes_and_values(
770-
graph_or_function,
771-
node,
772-
delta.match.nodes if rule.remove_nodes else [],
773-
delta.new_nodes,
774-
delta.match.outputs,
775-
delta.new_outputs,
776-
)
779+
# Check if this is a flexible output case (matched node has more outputs than captured)
780+
flexible_node = None
781+
for matched_node in delta.match.nodes:
782+
if len(matched_node.outputs) > len(delta.match.outputs):
783+
flexible_node = matched_node
784+
break
785+
786+
if flexible_node and len(delta.new_outputs) == len(flexible_node.outputs):
787+
# Flexible output replacement: replace all outputs of the flexible node
788+
convenience.replace_nodes_and_values(
789+
graph_or_function,
790+
node,
791+
delta.match.nodes if rule.remove_nodes else [],
792+
delta.new_nodes,
793+
flexible_node.outputs,
794+
delta.new_outputs,
795+
)
796+
else:
797+
# Standard replacement
798+
convenience.replace_nodes_and_values(
799+
graph_or_function,
800+
node,
801+
delta.match.nodes if rule.remove_nodes else [],
802+
delta.new_nodes,
803+
delta.match.outputs,
804+
delta.new_outputs,
805+
)
777806

778807
if merge_metadata:
779808
_default_metadata_merger.copy_merged_metadata(

onnxscript/rewriter/pattern_test.py

Lines changed: 242 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -989,5 +989,247 @@ def test_pattern_builder_context(self):
989989
self.assertEqual(ops, ["Op1", "Op2", "Add", "Op3", "Mul"])
990990

991991

992+
class FlexibleOutputTest(unittest.TestCase):
993+
"""Test patterns with flexible output counts using _allow_flexible_outputs."""
994+
995+
def test_flexible_outputs_with_split(self):
996+
"""Test that _allow_flexible_outputs works for Split with varying output counts."""
997+
998+
def relu_split_pattern(op, x):
999+
relu = op.Relu(x)
1000+
return op.Split(relu, _allow_flexible_outputs=True)
1001+
1002+
def relu_split_rewrite(op, _match=None, x=None):
1003+
if x is None or _match is None:
1004+
return None
1005+
1006+
split = next((n for n in _match.nodes if n.op_type == "Split"), None)
1007+
if not split:
1008+
return None
1009+
1010+
num_outputs = len(split.outputs)
1011+
split_results = op.Split(x, _outputs=num_outputs, **split.attributes)
1012+
1013+
return tuple(op.Relu(s) for s in split_results) if num_outputs > 1 else op.Relu(split_results)
1014+
1015+
rule = pattern.RewriteRule(relu_split_pattern, relu_split_rewrite)
1016+
1017+
# Test model with Relu -> Split pattern
1018+
model_proto = onnx.parser.parse_model(
1019+
"""
1020+
<ir_version: 7, opset_import: [ "" : 18]>
1021+
agraph (float[10] x) => (float[5] out1, float[5] out2)
1022+
{
1023+
relu_out = Relu(x)
1024+
out1, out2 = Split<axis=0, num_outputs=2>(relu_out)
1025+
}
1026+
"""
1027+
)
1028+
1029+
optimized = onnxscript.rewriter.rewrite(model_proto, pattern_rewrite_rules=[rule])
1030+
1031+
# Verify transformation: 1 Relu + 1 Split -> 1 Split + 2 Relu
1032+
def count_ops(proto, op_type):
1033+
return sum(1 for n in proto.graph.node if n.op_type == op_type)
1034+
1035+
self.assertEqual(count_ops(model_proto, "Relu"), 1)
1036+
self.assertEqual(count_ops(model_proto, "Split"), 1)
1037+
self.assertEqual(count_ops(optimized, "Relu"), 2)
1038+
self.assertEqual(count_ops(optimized, "Split"), 1)
1039+
1040+
def test_flexible_outputs_without_match_parameter(self):
1041+
def pattern_func(op, x):
1042+
return op.Split(x, _allow_flexible_outputs=True)
1043+
1044+
def rewrite_func(op, x=None):
1045+
if x is None:
1046+
return None
1047+
return op.Split(op.Relu(x), _outputs=2)
1048+
1049+
rule = pattern.RewriteRule(pattern_func, rewrite_func)
1050+
1051+
model_proto = onnx.parser.parse_model(
1052+
"""
1053+
<ir_version: 7, opset_import: [ "" : 18]>
1054+
agraph (float[10] x) => (float[5] out1, float[5] out2)
1055+
{
1056+
out1, out2 = Split<axis=0, num_outputs=2>(x)
1057+
}
1058+
"""
1059+
)
1060+
1061+
optimized = onnxscript.rewriter.rewrite(model_proto, pattern_rewrite_rules=[rule])
1062+
1063+
def count_ops(proto, op_type):
1064+
return sum(1 for n in proto.graph.node if n.op_type == op_type)
1065+
1066+
self.assertEqual(count_ops(optimized, "Relu"), 1)
1067+
self.assertEqual(count_ops(optimized, "Split"), 1)
1068+
1069+
def test_output_count_validation_without_flexible(self):
1070+
def pattern_func(op, x):
1071+
return op.Relu(x)
1072+
1073+
def bad_rewrite_func(op, x=None):
1074+
return op.Split(x, _outputs=2)
1075+
1076+
rule = pattern.RewriteRule(pattern_func, bad_rewrite_func)
1077+
1078+
model_proto = onnx.parser.parse_model(
1079+
"""
1080+
<ir_version: 7, opset_import: [ "" : 18]>
1081+
agraph (float[10] x) => (float[10] out)
1082+
{
1083+
out = Relu(x)
1084+
}
1085+
"""
1086+
)
1087+
1088+
with self.assertRaises(ValueError) as ctx:
1089+
onnxscript.rewriter.rewrite(model_proto, pattern_rewrite_rules=[rule])
1090+
1091+
self.assertIn("Number of outputs", str(ctx.exception))
1092+
1093+
def test_standard_replacement_path(self):
1094+
def pattern_func(op, x):
1095+
return op.Relu(x)
1096+
1097+
def rewrite_func(op, x=None):
1098+
return op.Sigmoid(x)
1099+
1100+
rule = pattern.RewriteRule(pattern_func, rewrite_func)
1101+
1102+
model_proto = onnx.parser.parse_model(
1103+
"""
1104+
<ir_version: 7, opset_import: [ "" : 18]>
1105+
agraph (float[10] x) => (float[10] out)
1106+
{
1107+
out = Relu(x)
1108+
}
1109+
"""
1110+
)
1111+
1112+
optimized = onnxscript.rewriter.rewrite(model_proto, pattern_rewrite_rules=[rule])
1113+
1114+
def count_ops(proto, op_type):
1115+
return sum(1 for n in proto.graph.node if n.op_type == op_type)
1116+
1117+
self.assertEqual(count_ops(optimized, "Relu"), 0)
1118+
self.assertEqual(count_ops(optimized, "Sigmoid"), 1)
1119+
1120+
def test_flexible_outputs_with_three_outputs(self):
1121+
def pattern_func(op, x):
1122+
return op.Split(x, _allow_flexible_outputs=True)
1123+
1124+
def rewrite_func(op, _match=None, x=None):
1125+
if x is None or _match is None:
1126+
return None
1127+
1128+
split = next((n for n in _match.nodes if n.op_type == "Split"), None)
1129+
if not split:
1130+
return None
1131+
1132+
num_outputs = len(split.outputs)
1133+
relu = op.Relu(x)
1134+
split_results = op.Split(relu, _outputs=num_outputs, **split.attributes)
1135+
return split_results
1136+
1137+
rule = pattern.RewriteRule(pattern_func, rewrite_func)
1138+
1139+
model_proto = onnx.parser.parse_model(
1140+
"""
1141+
<ir_version: 7, opset_import: [ "" : 18]>
1142+
agraph (float[15] x) => (float[5] out1, float[5] out2, float[5] out3)
1143+
{
1144+
out1, out2, out3 = Split<axis=0, num_outputs=3>(x)
1145+
}
1146+
"""
1147+
)
1148+
1149+
optimized = onnxscript.rewriter.rewrite(model_proto, pattern_rewrite_rules=[rule])
1150+
1151+
def count_ops(proto, op_type):
1152+
return sum(1 for n in proto.graph.node if n.op_type == op_type)
1153+
1154+
self.assertEqual(count_ops(optimized, "Relu"), 1)
1155+
self.assertEqual(count_ops(optimized, "Split"), 1)
1156+
1157+
def test_flexible_outputs_with_single_output(self):
1158+
def pattern_func(op, x):
1159+
return op.Split(x, _allow_flexible_outputs=True)
1160+
1161+
def rewrite_func(op, _match=None, x=None):
1162+
if x is None or _match is None:
1163+
return None
1164+
1165+
split = next((n for n in _match.nodes if n.op_type == "Split"), None)
1166+
if not split:
1167+
return None
1168+
1169+
relu = op.Relu(x)
1170+
if len(split.outputs) == 1:
1171+
return op.Split(relu, _outputs=1, **split.attributes)
1172+
return relu
1173+
1174+
rule = pattern.RewriteRule(pattern_func, rewrite_func)
1175+
1176+
model_proto = onnx.parser.parse_model(
1177+
"""
1178+
<ir_version: 7, opset_import: [ "" : 18]>
1179+
agraph (float[10] x) => (float[10] out)
1180+
{
1181+
out = Split<axis=0>(x)
1182+
}
1183+
"""
1184+
)
1185+
1186+
optimized = onnxscript.rewriter.rewrite(model_proto, pattern_rewrite_rules=[rule])
1187+
1188+
def count_ops(proto, op_type):
1189+
return sum(1 for n in proto.graph.node if n.op_type == op_type)
1190+
1191+
self.assertEqual(count_ops(optimized, "Relu"), 1)
1192+
self.assertEqual(count_ops(optimized, "Split"), 1)
1193+
1194+
def test_flexible_outputs_with_partial_usage(self):
1195+
def pattern_func(op, x):
1196+
return op.Split(x, _allow_flexible_outputs=True)
1197+
1198+
def rewrite_func(op, _match=None, x=None):
1199+
if x is None or _match is None:
1200+
return None
1201+
1202+
split = next((n for n in _match.nodes if n.op_type == "Split"), None)
1203+
if not split:
1204+
return None
1205+
1206+
num_outputs = len(split.outputs)
1207+
relu = op.Relu(x)
1208+
return op.Split(relu, _outputs=num_outputs, **split.attributes)
1209+
1210+
rule = pattern.RewriteRule(pattern_func, rewrite_func)
1211+
1212+
model_proto = onnx.parser.parse_model(
1213+
"""
1214+
<ir_version: 7, opset_import: [ "" : 18]>
1215+
agraph (float[10] x) => (float[5] out1, float[5] out2, float[5] sum)
1216+
{
1217+
s1, s2 = Split<axis=0, num_outputs=2>(x)
1218+
out1 = Abs(s1)
1219+
out2 = Neg(s2)
1220+
sum = Add(out1, out2)
1221+
}
1222+
"""
1223+
)
1224+
1225+
optimized = onnxscript.rewriter.rewrite(model_proto, pattern_rewrite_rules=[rule])
1226+
1227+
def count_ops(proto, op_type):
1228+
return sum(1 for n in proto.graph.node if n.op_type == op_type)
1229+
1230+
self.assertEqual(count_ops(optimized, "Relu"), 1)
1231+
self.assertEqual(count_ops(optimized, "Split"), 1)
1232+
1233+
9921234
if __name__ == "__main__":
9931235
unittest.main()

0 commit comments

Comments
 (0)