Skip to content

Commit e8560f7

Browse files
Copilotgramalingam
andcommitted
Fix tutorial input field name and rename PatternMatchContext to MatchContext
Co-authored-by: gramalingam <10075881+gramalingam@users.noreply.github.com>
1 parent 23b2e12 commit e8560f7

File tree

5 files changed

+14
-14
lines changed

5 files changed

+14
-14
lines changed

docs/tutorial/rewriter/conditional_rewrite.md

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -51,13 +51,13 @@ The final graph with the applied rewrite looks as follows:
5151

5252
![broadcast_rewrite](examples/img/broadcast_02.png){align=center}
5353

54-
# Using PatternMatchContext for Advanced Condition Checking
54+
# Using MatchContext for Advanced Condition Checking
5555

56-
The `context` parameter passed to condition functions is an instance of {py:class}`onnxscript.rewriter.PatternMatchContext`, which provides access to additional information about the pattern match that can be useful for sophisticated condition checking.
56+
The `context` parameter passed to condition functions is an instance of {py:class}`onnxscript.rewriter.MatchContext`, which provides access to additional information about the pattern match that can be useful for sophisticated condition checking.
5757

58-
## PatternMatchContext Properties
58+
## MatchContext Properties
5959

60-
The PatternMatchContext provides the following read-only properties:
60+
The MatchContext provides the following read-only properties:
6161

6262
- `model`: The entire ONNX model being matched
6363
- `graph_or_function`: The specific graph or function being matched
@@ -67,11 +67,11 @@ The PatternMatchContext provides the following read-only properties:
6767

6868
## Example Usage
6969

70-
Here's an example showing how to use the PatternMatchContext to implement more sophisticated condition checking:
70+
Here's an example showing how to use the MatchContext to implement more sophisticated condition checking:
7171

7272
```python
7373
def advanced_condition_check(context, x, y, **_):
74-
"""Example condition function using PatternMatchContext."""
74+
"""Example condition function using MatchContext."""
7575

7676
# Access the main node of the pattern match
7777
main_node = context.main_root_node
@@ -82,7 +82,7 @@ def advanced_condition_check(context, x, y, **_):
8282

8383
# Access the broader graph context and check that x occurs as a graph-input
8484
model = context.model
85-
if x not in model.graph.input:
85+
if x not in model.graph.inputs:
8686
return False
8787

8888
# You can inspect the matched nodes for advanced validation

onnxscript/rewriter/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
"rewrite",
1010
"RewritePass",
1111
"MatchResult",
12-
"PatternMatchContext",
12+
"MatchContext",
1313
"RewriteRule",
1414
"RewriteRuleClassBase",
1515
"RewriteRuleSet",
@@ -32,7 +32,7 @@
3232
pattern,
3333
redundant_scatter_nd,
3434
)
35-
from onnxscript.rewriter._basics import MatchingTracer, MatchResult, MatchStatus, PatternMatchContext
35+
from onnxscript.rewriter._basics import MatchingTracer, MatchResult, MatchStatus, MatchContext
3636
from onnxscript.rewriter._rewrite_rule import (
3737
RewriterContext,
3838
RewriteRule,

onnxscript/rewriter/_basics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,7 @@ def print(self):
340340
print(separator)
341341

342342

343-
class PatternMatchContext:
343+
class MatchContext:
344344
"""A read-only context containing information about a pattern match.
345345
346346
This class captures information about the context describing a match to a given pattern,

onnxscript/rewriter/_rewrite_rule.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def match(
131131
remove_nodes=check_nodes_are_removable,
132132
)
133133
if match:
134-
context = _basics.PatternMatchContext(model, graph_or_function, node, match)
134+
context = _basics.MatchContext(model, graph_or_function, node, match)
135135
for var in self._target_pattern.inputs:
136136
if var.name is not None:
137137
if var.name not in match.bindings:

onnxscript/rewriter/pattern_match_context_test.py renamed to onnxscript/rewriter/match_context_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Copyright (c) Microsoft Corporation.
22
# Licensed under the MIT License.
3-
"""Test for PatternMatchContext functionality."""
3+
"""Test for MatchContext functionality."""
44

55
import unittest
66

@@ -10,9 +10,9 @@
1010
from onnxscript.rewriter import pattern
1111

1212

13-
class PatternMatchContextTest(unittest.TestCase):
13+
class MatchContextTest(unittest.TestCase):
1414
def test_context_usage_in_condition_function(self):
15-
"""Test that PatternMatchContext can be meaningfully used in condition functions."""
15+
"""Test that MatchContext can be meaningfully used in condition functions."""
1616

1717
model_proto = onnx.parser.parse_model(
1818
"""

0 commit comments

Comments
 (0)