Skip to content

Commit 02b5e5d

Browse files
authored
Merge branch 'main' into copilot/fix-2405
2 parents 2b26f63 + 38c4468 commit 02b5e5d

File tree

116 files changed

+5069
-4205
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

116 files changed

+5069
-4205
lines changed

.github/copilot-instructions.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
## Code Standards
2+
3+
### Required Before Each Commit
4+
- Run `lintrunner -a` before committing any changes to ensure proper code formatting
5+
- This will run lintrunner on all updated files to maintain consistent style

.lintrunner.toml

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -50,12 +50,12 @@ exclude_patterns = [
5050
'onnxscript/optimizer/_legacy/constant_folding.py', # FIXME
5151
'onnxscript/rewriter/onnxruntime/transformers/fastgelu.py', # FIXME
5252
'onnxscript/rewriter/onnxruntime/instance_to_group_normalization.py', # FIXME
53-
'onnxscript/rewriter/ort_fusions/_smollm_*.py', # onnxscript code
53+
'onnxscript/rewriter/ort_fusions/models/*.py', # onnxscript code
54+
'onnxscript/rewriter/ort_fusions/models/_phi2lm.py', # onnxscript code
55+
'onnxscript/rewriter/ort_fusions/models/_phi4lm.py', # onnxscript code
5456
'onnxscript/rewriter/ort_fusions/_rotary_embedding_models.py', # onnxscript code
55-
'onnxscript/_legacy_ir/irbuilder.py', # FIXME
5657
'onnxscript/rewriter/onnxruntime/transformers/multihead_attention.py', # FIXME
5758
'onnxscript/tools/function_unittest_producer.py', # FIXME
58-
'onnxscript/_legacy_ir/visitor.py', # FIXME
5959
'onnxscript/rewriter/onnxruntime/transformers/layernorm.py', # FIXME
6060
'onnxscript/rewriter/generic_pattern.py', # FIXME
6161
]
@@ -114,16 +114,14 @@ include_patterns = [
114114
'**/*.py',
115115
]
116116
exclude_patterns = [
117-
'examples/**', # TODO: Merge with docs/examples
118-
'docs/examples/**',
119-
'docs/tutorial/examples/**',
117+
'examples/**',
118+
'docs/**',
120119
'onnxscript/converter_test.py',
121120
'tests/functions/**',
122121
'tests/models/**',
123122
'tests/onnx_backend_test_code/**',
124123
'onnxscript/optimizer/**', # FIXME
125124
'onnxscript/rewriter/**', # FIXME
126-
'onnxscript/_legacy_ir/**', # FIXME
127125
]
128126
command = [
129127
'python',

VERSION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
0.3.0
1+
0.4.0

docs/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@
8989
"matplotlib": ("https://matplotlib.org/stable/", None),
9090
"numpy": ("https://numpy.org/doc/stable/", None),
9191
"onnx": ("https://onnx.ai/onnx/", None),
92+
"onnx_ir": ("https://onnx.ai/ir-py/", None),
9293
"onnxruntime": ("https://onnxruntime.ai/docs/api/python/", None),
9394
"scipy": ("https://docs.scipy.org/doc/scipy/", None),
9495
"torch": ("https://pytorch.org/docs/main/", None),

docs/tutorial/rewriter/conditional_rewrite.md

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ Similarly for writing the condition checking function, we require only `input_a`
3232
:::
3333

3434
In order to validate whether matmul broadcast is sufficient, we write a condition checking function as below.
35-
Note that the relevant inputs passed to the check function are all instances of :class:`onnx_ir.Value`. These represent
35+
Note that the relevant inputs passed to the check function are all instances of {py:class}`onnx_ir.Value`. These represent
3636
the values in the input graph IR that matched against the corresponding _pattern variables_ in the target
3737
pattern. Please see documentation of the [IR API](https://onnx.ai/ir-py/) for more details on how to use it, for example to identify
3838
the type or shape or rank of these values.
@@ -51,3 +51,53 @@ The final graph with the applied rewrite looks as follows:
5151

5252
![broadcast_rewrite](examples/img/broadcast_02.png){align=center}
5353

54+
# Using MatchContext for Advanced Condition Checking
55+
56+
The `context` parameter passed to condition functions is an instance of {py:class}`onnxscript.rewriter.MatchContext`, which provides access to additional information about the pattern match that can be useful for sophisticated condition checking.
57+
58+
## MatchContext Properties
59+
60+
The MatchContext provides the following read-only properties:
61+
62+
- `model`: The entire ONNX model being matched
63+
- `graph_or_function`: The specific graph or function being matched
64+
- `root`: The root node of the matching subgraph
65+
- `output_values`: The output values of the matching subgraph
66+
- `nodes`: All nodes that are part of the matching subgraph
67+
68+
## Example Usage
69+
70+
Here's an example showing how to use the MatchContext to implement more sophisticated condition checking:
71+
72+
```python
73+
def advanced_condition_check(context, x, y, **_):
74+
"""Example condition function using MatchContext."""
75+
76+
# Access the main node of the pattern match
77+
main_node = context.root
78+
79+
# Check that the main_node does not have an attribute called "alpha"
80+
if "alpha" in main_node.attributes:
81+
return False
82+
83+
# Access the broader graph context and check that x occurs as a graph-input
84+
model = context.model
85+
if x not in model.graph.inputs:
86+
return False
87+
88+
# You can inspect the matched nodes for advanced validation
89+
for node in context.nodes:
90+
if node.op_type == "Constant":
91+
# Check properties of constant nodes in the match
92+
pass
93+
94+
# Access output values for shape/type validation
95+
outputs = context.output_values
96+
if len(outputs) > 0 and outputs[0].shape is not None:
97+
# Validate output shapes
98+
pass
99+
100+
return True
101+
```
102+
103+
This context information enables condition functions to make decisions based on the broader graph structure, the specific nodes involved in the match, and relationships between matched patterns and the rest of the model.

docs/tutorial/rewriter/examples/broadcast_matmul.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,6 @@ def check_if_not_need_reshape(
7979
Returns:
8080
True if we need to replace the pattern, False otherwise.
8181
"""
82-
del context # Reserved for future extensions
83-
8482
input_a_shape = input_a.shape
8583
input_b_shape = input_b.shape
8684
shape_c_tensor = shape_c.const_value
Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
(heading-target-checkers)=
2+
# Node and Value Level Checkers
3+
4+
The pattern matching infrastructure supports custom validation logic at both the node and value levels through checker functions. These checkers allow for more sophisticated pattern matching by enabling additional constraints beyond basic operator and structure matching.
5+
6+
## Value-Level Checkers
7+
8+
Value-level checkers validate properties of specific values in the pattern. They are particularly useful for checking constants, shapes, or other value-specific properties.
9+
10+
### Basic Usage
11+
12+
A value checker is a function that takes a `MatchContext` and an `ir.Value`, and returns either a boolean or a `MatchResult`:
13+
14+
```python
15+
def is_positive_constant(context, value: ir.Value):
16+
"""Check if a value is a positive constant."""
17+
if value.const_value is not None:
18+
# Get the numpy array from const_value
19+
numpy_array = value.const_value.numpy()
20+
21+
# Check if it represents a single value and is positive
22+
if numpy_array.size != 1:
23+
return False
24+
25+
return float(numpy_array.item()) > 0
26+
27+
return False
28+
```
29+
30+
You can use this checker directly in your pattern by passing the callable as an input:
31+
32+
```python
33+
def add_pattern(op, x, y):
34+
# Use callable as input to create ValuePattern with checker
35+
return op.Add(is_positive_constant, y)
36+
```
37+
38+
This pattern will only match `Add` operations where the first input is a positive constant value.
39+
40+
### Example Usage
41+
42+
```python
43+
from onnxscript.rewriter import pattern
44+
from onnxscript import ir, optimizer
45+
import onnx
46+
47+
# Create a model with different Add operations
48+
model_proto = onnx.parser.parse_model("""
49+
<ir_version: 7, opset_import: [ "" : 17]>
50+
agraph (float[N] x, float[N] y) => (float[N] z1, float[N] z2, float[N] z3)
51+
{
52+
pos_const = Constant <value_float = 2.5> ()
53+
neg_const = Constant <value_float = -1.5> ()
54+
z1 = Add(x, y) # non-constant first parameter
55+
z2 = Add(pos_const, y) # positive constant first parameter
56+
z3 = Add(neg_const, y) # negative constant first parameter
57+
}
58+
""")
59+
model = ir.serde.deserialize_model(model_proto)
60+
61+
# Apply constant propagation to set const_value fields
62+
optimizer.basic_constant_propagation(model.graph.all_nodes())
63+
64+
# Create the pattern with value checker
65+
rule_pattern = pattern.Pattern(add_pattern)
66+
67+
# Test matching against different Add nodes
68+
add_nodes = [node for node in model.graph if node.op_type == "Add"]
69+
70+
# Non-constant first parameter - will not match
71+
match_result = rule_pattern.match(model, model.graph, add_nodes[0])
72+
print(f"Non-constant: {bool(match_result)}") # False
73+
74+
# Positive constant first parameter - will match
75+
match_result = rule_pattern.match(model, model.graph, add_nodes[1])
76+
print(f"Positive constant: {bool(match_result)}") # True
77+
78+
# Negative constant first parameter - will not match
79+
match_result = rule_pattern.match(model, model.graph, add_nodes[2])
80+
print(f"Negative constant: {bool(match_result)}") # False
81+
```
82+
83+
## Node-Level Checkers
84+
85+
Node-level checkers validate properties of the operation nodes themselves, such as attributes, operation types, or other node-specific properties.
86+
87+
### Basic Usage
88+
89+
A node checker is a function that takes a `MatchContext` and an `ir.Node`, and returns either a boolean or a `MatchResult`:
90+
91+
```python
92+
def shape_node_checker(context, node):
93+
"""Check if a Shape operation has start attribute equal to 0."""
94+
return node.attributes.get_int("start", 0) == 0
95+
```
96+
97+
You can use this checker by passing it to the `_check` parameter of an operation:
98+
99+
```python
100+
def shape_pattern(op, x):
101+
return op.Shape(x, _check=shape_node_checker)
102+
```
103+
104+
This pattern will only match `Shape` operations where the `start` attribute is 0 (or not present, as the default is 0).
105+
106+
### Example Usage
107+
108+
```python
109+
from onnxscript.rewriter import pattern
110+
from onnxscript import ir
111+
import onnx
112+
113+
# Create a model with different Shape operations
114+
model_proto = onnx.parser.parse_model("""
115+
<ir_version: 7, opset_import: [ "" : 17]>
116+
agraph (float[N, M] x) => (int64[2] z1, int64[2] z2, int64[1] z3)
117+
{
118+
z1 = Shape(x)
119+
z2 = Shape <start: int = 0>(x)
120+
z3 = Shape <start: int = 1>(x)
121+
}
122+
""")
123+
model = ir.serde.deserialize_model(model_proto)
124+
125+
# Create the pattern with node checker
126+
rule_pattern = pattern.Pattern(shape_pattern)
127+
128+
# Test matching against different Shape nodes
129+
nodes = list(model.graph)
130+
shape_nodes = [node for node in nodes if node.op_type == "Shape"]
131+
132+
# Shape without start attribute (default 0) - will match
133+
match_result = rule_pattern.match(model, model.graph, shape_nodes[0])
134+
print(f"No start attr: {bool(match_result)}") # True
135+
136+
# Shape with start=0 - will match
137+
match_result = rule_pattern.match(model, model.graph, shape_nodes[1])
138+
print(f"Start=0: {bool(match_result)}") # True
139+
140+
# Shape with start=1 - will not match
141+
match_result = rule_pattern.match(model, model.graph, shape_nodes[2])
142+
print(f"Start=1: {bool(match_result)}") # False
143+
```
144+
145+
## Combining Checkers
146+
147+
You can combine both node-level and value-level checkers in the same pattern for more sophisticated matching:
148+
149+
```python
150+
def complex_pattern(op, x, y):
151+
# Value-level checker for first input
152+
validated_x = is_positive_constant
153+
# Node-level checker for the operation
154+
return op.Add(validated_x, y, _check=lambda ctx, node: len(node.attributes) == 0)
155+
```
156+
157+
This pattern will only match `Add` operations where:
158+
1. The first input is a positive constant (value-level check)
159+
2. The node has no custom attributes (node-level check)
160+
161+
## Execution Timing and Limitations
162+
163+
### When Checkers Are Called
164+
165+
Node-level and value-level checkers are called **only at the end of the complete structural match**. This means:
166+
167+
1. **Structural matching happens first**: The pattern matching engine first validates that the graph structure matches the pattern (correct operators, connections, etc.)
168+
2. **Checkers run after structural validation**: Only after the structural match succeeds do the node and value checkers execute
169+
3. **Order of execution**: Value-level checkers run first, followed by node-level checkers, and finally the pattern's condition function
170+
171+
### Limitations with Pattern Disjunctions
172+
173+
One important limitation of this design is that these checks don't compose well with pattern disjunctions (multiple alternative patterns). When searching among multiple value patterns:
174+
175+
- **Only structural checking is performed initially**: If structural matching succeeds for the first alternative, other alternatives are not considered
176+
- **Checker failures don't trigger backtracking**: If a checker fails, the entire pattern match fails rather than trying the next alternative pattern
177+
178+
This means you should be careful when designing patterns with multiple alternatives that rely on checkers, as the checker logic may prevent exploration of valid alternative matches.
179+
180+
## Error Handling
181+
182+
Checkers can return either:
183+
- `True`: Check passed, continue matching
184+
- `False`: Check failed, pattern does not match
185+
- `MatchResult`: More detailed result with potential failure reasons
186+
187+
If a checker raises an exception, it will be caught and treated as a match failure, allowing patterns to fail gracefully when encountering unexpected conditions.

docs/tutorial/rewriter/rewrite_patterns.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,3 +44,6 @@ These options are documented in detail in the following sections.
4444

4545
```{include} commute.md
4646
```
47+
48+
```{include} node_value_checkers.md
49+
```

docs/tutorial/rewriter/simple_example.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ After this, create a replacement pattern that consists of the GELU onnxscript op
3333
:::{note}
3434
:name: type annotate ir.Value
3535

36-
The inputs to the replacement pattern are of type `ir.Value`. For detailed usage of `ir.Value` refer to the {py:class}`ir.Value <onnxscript.ir._core.Value>` class.
36+
The inputs to the replacement pattern are of type `ir.Value`. For detailed usage of `ir.Value` refer to the {py:class}`ir.Value <onnx_ir.Value>` class.
3737
:::
3838

3939

0 commit comments

Comments
 (0)