Skip to content

Commit 2d0cc42

Browse files
[Flexible Outputs] Allow multi-output matching
Adds support for pattern matching on multiple outputs for optimizations involving ops like Split.

```python
def pattern(self, op, x):
    relu = op.Relu(x)
    return op.Split(relu, _allow_flexible_outputs=True)
```

Fixes #2581
1 parent 1077da7 commit 2d0cc42

File tree

4 files changed

+101
-13
lines changed

4 files changed

+101
-13
lines changed

onnxscript/rewriter/_matcher.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -310,9 +310,12 @@ def _match_single_output_node(
310310
if output_values is None:
311311
# TODO(rama): Is this a valid (useful) case?
312312
return match
313-
if check_removable and not _valid_to_replace(match.nodes, output_values):
314-
# TODO(rama): Match status should be updated to reflect failure reason.
315-
return match.fail("Matched nodes have other uses preventing replacement.")
313+
# Skip removability check for flexible output nodes since they may have
314+
# additional outputs beyond those captured in the pattern
315+
if check_removable and not pattern.output_node.allow_flexible_outputs:
316+
if not _valid_to_replace(match.nodes, output_values):
317+
# TODO(rama): Match status should be updated to reflect failure reason.
318+
return match.fail("Matched nodes have other uses preventing replacement.")
316319

317320
match.outputs.extend(output_values)
318321
return match

onnxscript/rewriter/_pattern_ir.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,7 @@ def __call__(
248248
_outputs: int | list[str | None] = 1,
249249
_allow_other_attributes: bool | None = None,
250250
_allow_other_inputs: bool | None = None,
251+
_allow_flexible_outputs: bool | None = None,
251252
_check: Callable | None = None,
252253
**kwargs,
253254
):
@@ -280,6 +281,7 @@ def __call__(
280281
_outputs,
281282
allow_other_attributes=_allow_other_attributes,
282283
allow_other_inputs=_allow_other_inputs,
284+
allow_flexible_outputs=_allow_flexible_outputs,
283285
check=_check,
284286
)
285287
self.pattern_builder.add_node(node_pattern)
@@ -440,6 +442,7 @@ def __init__(
440442
*,
441443
allow_other_attributes: bool | None,
442444
allow_other_inputs: bool | None,
445+
allow_flexible_outputs: bool | None = None,
443446
check: Callable | None = None,
444447
):
445448
if allow_other_attributes is None:
@@ -448,12 +451,16 @@ def __init__(
448451
if allow_other_inputs is None:
449452
# TODO(rama): Should we default to True? For now, we preserve the current behavior.
450453
allow_other_inputs = False
454+
if allow_flexible_outputs is None:
455+
# Default behavior: do not match flexible outputs
456+
allow_flexible_outputs = False
451457
self.domain = domain
452458
self.op = StringConstantPattern(op) if isinstance(op, str) else op
453459
self.inputs = [_to_value_pattern(x) for x in inputs]
454460
self.attributes = attributes
455461
self.allow_other_attributes = allow_other_attributes
456462
self.allow_other_inputs = allow_other_inputs
463+
self.allow_flexible_outputs = allow_flexible_outputs
457464
self._check = check
458465
# In the common case, domain and op are constants, which can be used to optimize matching.
459466
if isinstance(op, str) and isinstance(domain, StringConstantPattern):

onnxscript/rewriter/_rewrite_rule.py

Lines changed: 39 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import abc
88
import dataclasses
9+
import inspect
910
import itertools
1011
from typing import (
1112
Callable,
@@ -214,11 +215,14 @@ class ReplacementPatternFunction:
214215

215216
def __init__(self, function) -> None:
216217
self._function = function
218+
# Cache signature inspection to avoid repeated introspection on hot path
219+
self._accepts_match = "_match" in inspect.signature(function).parameters
217220

218221
def get_replacement(self, match: _basics.MatchResult) -> ReplacementSubgraph | None:
219222
context = RewriterContext()
223+
bindings = match.bindings if not self._accepts_match else {**match.bindings, "_match": match}
220224
try:
221-
new_outputs = self._function(context, **match.bindings)
225+
new_outputs = self._function(context, **bindings)
222226
except _basics.MatchFailureError as e:
223227
match.fail(e.reason, list(e.failure_sources))
224228
return None
@@ -313,6 +317,11 @@ def __init__(
313317
# Initialize the base pattern matching functionality
314318
super().__init__(target_pattern, condition_function, matcher, verbose, name)
315319

320+
# Check if any node in the pattern uses flexible outputs (cache for hot path)
321+
self._has_flexible_outputs = any(
322+
node.allow_flexible_outputs for node in self._target_pattern._nodes
323+
)
324+
316325
if not isinstance(replacement_pattern, ReplacementPatternFunction):
317326
replacement_pattern = ReplacementPatternFunction(replacement_pattern)
318327
self._replacement_pattern = replacement_pattern
@@ -357,7 +366,8 @@ def try_rewrite(
357366
_basics.MatchStatus.REPLACEMENT_FAILED,
358367
)
359368
return None
360-
if len(replacement_subgraph.new_outputs) != self._target_pattern.num_outputs:
369+
370+
if not self._has_flexible_outputs and len(replacement_subgraph.new_outputs) != self._target_pattern.num_outputs:
361371
raise ValueError(
362372
f"Number of outputs from replacement function does not match the number of outputs from the target pattern. "
363373
f"Expected {self._target_pattern.num_outputs}, but got {len(replacement_subgraph.new_outputs)}."
@@ -766,14 +776,33 @@ def _apply_to_graph_or_function(
766776
for n in delta.new_nodes:
767777
n.metadata_props[RULE_NAME_TAG] = rule.name
768778

769-
convenience.replace_nodes_and_values(
770-
graph_or_function,
771-
node,
772-
delta.match.nodes if rule.remove_nodes else [],
773-
delta.new_nodes,
774-
delta.match.outputs,
775-
delta.new_outputs,
776-
)
779+
# Check if this is a flexible output case (matched node has more outputs than captured)
780+
flexible_node = None
781+
for matched_node in delta.match.nodes:
782+
if len(matched_node.outputs) > len(delta.match.outputs):
783+
flexible_node = matched_node
784+
break
785+
786+
if flexible_node and len(delta.new_outputs) == len(flexible_node.outputs):
787+
# Flexible output replacement: replace all outputs of the flexible node
788+
convenience.replace_nodes_and_values(
789+
graph_or_function,
790+
node,
791+
delta.match.nodes if rule.remove_nodes else [],
792+
delta.new_nodes,
793+
flexible_node.outputs,
794+
delta.new_outputs,
795+
)
796+
else:
797+
# Standard replacement
798+
convenience.replace_nodes_and_values(
799+
graph_or_function,
800+
node,
801+
delta.match.nodes if rule.remove_nodes else [],
802+
delta.new_nodes,
803+
delta.match.outputs,
804+
delta.new_outputs,
805+
)
777806

778807
if merge_metadata:
779808
_default_metadata_merger.copy_merged_metadata(

onnxscript/rewriter/pattern_test.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -989,5 +989,54 @@ def test_pattern_builder_context(self):
989989
self.assertEqual(ops, ["Op1", "Op2", "Add", "Op3", "Mul"])
990990

991991

992+
class FlexibleOutputTest(unittest.TestCase):
    """Exercise patterns whose final node opts into ``_allow_flexible_outputs``."""

    def test_flexible_outputs_with_split(self):
        """Rewrite ``Relu -> Split`` into ``Split -> Relu`` applied to every split output."""

        def relu_split_pattern(op, x):
            # Target: a Relu feeding a Split whose number of outputs is not fixed
            # by the pattern.
            activated = op.Relu(x)
            return op.Split(activated, _allow_flexible_outputs=True)

        def relu_split_rewrite(op, _match=None, x=None):
            # Without the match object or the bound input there is nothing to build.
            if _match is None or x is None:
                return None

            # Find the matched Split so its output count and attributes can be reused.
            split_node = None
            for candidate in _match.nodes:
                if candidate.op_type == "Split":
                    split_node = candidate
                    break
            if not split_node:
                return None

            output_count = len(split_node.outputs)
            pieces = op.Split(x, _outputs=output_count, **split_node.attributes)

            # With a single requested output the Split result is passed to Relu
            # directly; otherwise one Relu is created per piece.
            if output_count > 1:
                return tuple(op.Relu(piece) for piece in pieces)
            return op.Relu(pieces)

        rule = pattern.RewriteRule(relu_split_pattern, relu_split_rewrite)

        # Input model: a Relu whose result is split into two outputs.
        model_proto = onnx.parser.parse_model(
            """
            <ir_version: 7, opset_import: [ "" : 18]>
            agraph (float[10] x) => (float[5] out1, float[5] out2)
            {
                relu_out = Relu(x)
                out1, out2 = Split<axis=0, num_outputs=2>(relu_out)
            }
            """
        )

        optimized = onnxscript.rewriter.rewrite(model_proto, pattern_rewrite_rules=[rule])

        def count_ops(proto, op_type):
            # Number of nodes in the main graph with the given op type.
            return len([node for node in proto.graph.node if node.op_type == op_type])

        # Expected transformation: 1 Relu + 1 Split -> 1 Split + 2 Relu.
        self.assertEqual(count_ops(model_proto, "Relu"), 1)
        self.assertEqual(count_ops(model_proto, "Split"), 1)
        self.assertEqual(count_ops(optimized, "Relu"), 2)
        self.assertEqual(count_ops(optimized, "Split"), 1)
1039+
1040+
9921041
if __name__ == "__main__":
9931042
unittest.main()

0 commit comments

Comments
 (0)