Skip to content

Commit 138cb30

Browse files
Copilotjustinchuby
andauthored
Fix MatchResult.fail() call signature in redundant_scatter_nd.py (#2431)
The `fail` helper function in `onnxscript/rewriter/redundant_scatter_nd.py` was incorrectly passing multiple arguments to `MatchResult.fail()`, causing a TypeError when pattern matching failed. ## Problem The error occurred when the rewriter tried to report match failures with multiple failure sources: ```python return fail("The shape of 'data' and 'updates' are different.", data, updates) ``` This resulted in: ``` TypeError: MatchResult.fail() takes from 1 to 3 positional arguments but 4 were given ``` The issue was that `MatchResult.fail()` only accepts 2 parameters after `self`: - `reason: str` - the failure reason - `failure_source: Union[ir.Node, ir.Value, list[...]] | None` - a single item or list of failure sources But the helper function was passing all arguments directly: `MatchResult().fail(*args)`. ## Solution Modified the `fail` helper function to properly handle multiple failure sources by collecting them into a list when calling `MatchResult.fail()`: ```python def fail(reason, *failure_sources): if failure_sources: return onnxscript.rewriter.MatchResult().fail(reason, list(failure_sources)) else: return onnxscript.rewriter.MatchResult().fail(reason) ``` This change: - ✅ Fixes the TypeError for calls with multiple failure sources - ✅ Maintains backward compatibility for existing single-argument calls - ✅ Follows the same pattern used correctly in other rewriter modules like `matmul_add_to_gemm.py` ## Testing Verified that all existing call patterns in the file work correctly: - `fail("message")` - reason only - `fail("message", node)` - reason + single source - `fail("message", node1, node2)` - reason + multiple sources Fixes #2430. <!-- START COPILOT CODING AGENT TIPS --> --- 💡 You can make Copilot smarter by setting up custom instructions, customizing its development environment and configuring Model Context Protocol (MCP) servers. Learn more [Copilot coding agent tips](https://gh.io/copilot-coding-agent-tips) in the docs. --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com>
1 parent f4534ee commit 138cb30

File tree

1 file changed

+16
-15
lines changed

1 file changed

+16
-15
lines changed

onnxscript/rewriter/redundant_scatter_nd.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,6 @@
2424
from onnxscript.rewriter import pattern as orp
2525

2626

27-
def fail(*args):
28-
return onnxscript.rewriter.MatchResult().fail(*args)
29-
30-
3127
class ScatterAllDynamic(orp.RewriteRuleClassBase):
3228
def pattern(self, op, data, axis, transposed_data, updates):
3329
# Construct update-indices spanning an entire axis:
@@ -41,24 +37,26 @@ def pattern(self, op, data, axis, transposed_data, updates):
4137
def check(self, context, data, axis, transposed_data, **_):
4238
# Check that updated-indices represent the full range of the first dimension of the transposed data.
4339
# That is: check that the data.shape[axis] matches transposed_data.shape[0].
40+
result = onnxscript.rewriter.MatchResult()
4441
axis_value = ir_utils.get_singleton_value(axis)
4542
if not isinstance(axis_value, int):
46-
return fail("Axis value must be a constant integer.", axis)
43+
return result.fail("Axis value must be a constant integer.", axis)
4744
shape: ir.Shape | None = data.shape
4845
if shape is None:
49-
return fail("Data shape is not statically known.", data)
46+
return result.fail("Data shape is not statically known.", data)
5047
updated_dim_value = shape[axis_value]
5148
transposed_data_shape: ir.Shape | None = transposed_data.shape
5249
if transposed_data_shape is None:
53-
return fail("Transposed data shape is not statically known.", transposed_data)
50+
return result.fail(
51+
"Transposed data shape is not statically known.", transposed_data
52+
)
5453
actual_dim_value = transposed_data_shape[0]
5554
if updated_dim_value != actual_dim_value:
5655
# The first dimension of the transposed data does not match the updated dimension,
5756
# so we cannot apply this rule.
58-
return fail(
57+
return result.fail(
5958
"The first dimension of the transposed data does not match the updated dimension.",
60-
data,
61-
transposed_data,
59+
[data, transposed_data],
6260
)
6361
return True
6462

@@ -81,20 +79,23 @@ def check(self, context, data, indices, updates, **_):
8179
"""Check if the ScatterND is redundant due to static indices covering entire tensor."""
8280
# To validate data can be replaced directly by updates, we need to check the following:
8381
# 1. they have the same shape
82+
result = onnxscript.rewriter.MatchResult()
8483
if data.shape is None:
85-
return fail("The value 'data' shape is not statically known.", data)
84+
return result.fail("The value 'data' shape is not statically known.", data)
8685
if updates.shape is None:
87-
return fail("The value 'updates' shape is not statically known.", updates)
86+
return result.fail("The value 'updates' shape is not statically known.", updates)
8887
if data.shape != updates.shape:
89-
return fail("The shape of 'data' and 'updates' are different.", data, updates)
88+
return result.fail(
89+
"The shape of 'data' and 'updates' are different.", [data, updates]
90+
)
9091

9192
# 2. the indices is referring to the whole data, which is from 0 to data.shape[0]
9293
if indices.const_value is None:
93-
return fail("The value 'indices' is not statically known.", indices)
94+
return result.fail("The value 'indices' is not statically known.", indices)
9495
expected_indices = [[i] for i in range(data.shape[0])]
9596
actual_indices = indices.const_value.numpy().tolist()
9697
if actual_indices != expected_indices:
97-
return fail("The 'indices' is not referring to the whole data.", indices)
98+
return result.fail("The 'indices' is not referring to the whole data.", indices)
9899

99100
return True
100101

0 commit comments

Comments
 (0)