Skip to content

Commit bf15d0c

Browse files
Copilotgramalingam
andcommitted
Update node checker test to use Shape operations with start attribute validation
Co-authored-by: gramalingam <10075881+gramalingam@users.noreply.github.com>
1 parent 8f6c474 commit bf15d0c

File tree

1 file changed

+28
-16
lines changed

1 file changed

+28
-16
lines changed

onnxscript/rewriter/pattern_test.py

Lines changed: 28 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -786,39 +786,51 @@ class ValueNodeCheckersTest(unittest.TestCase):
786786
def test_pattern_match_with_node_checker(self):
787787
"""Test Pattern.match with node-level checker."""
788788

789-
def add_node_checker(context, node):
790-
return node.op_type == "Add"
789+
def shape_node_checker(context, node):
790+
return node.attributes.get_int("start", 0) == 0
791791

792-
# Create a pattern that matches Add operations with a node checker
793-
def add_pattern(op, x, y):
794-
return op.Add(x, y, _check=add_node_checker)
792+
# Create a pattern that matches Shape operations with a node checker
793+
def shape_pattern(op, x):
794+
return op.Shape(x, _check=shape_node_checker)
795795

796796
# Create the pattern
797-
rule_pattern = pattern.Pattern(add_pattern)
797+
rule_pattern = pattern.Pattern(shape_pattern)
798798

799-
# Create a simple model
799+
# Create a model with multiple Shape nodes with different start attributes
800800
model_proto = onnx.parser.parse_model(
801801
"""
802802
<ir_version: 7, opset_import: [ "" : 17]>
803-
agraph (float[N] x, float[N] y) => (float[N] z)
803+
agraph (float[N, M] x) => (int64[2] z1, int64[2] z2, int64[1] z3)
804804
{
805-
z = Add(x, y)
805+
z1 = Shape(x)
806+
z2 = Shape <start: int = 0>(x)
807+
z3 = Shape <start: int = 1>(x)
806808
}
807809
"""
808810
)
809811
model = ir.serde.deserialize_model(model_proto)
810812

811-
# Find the Add node in the model
813+
# Find the Shape nodes in the model
812814
nodes = list(model.graph)
813-
add_node = nodes[0]
814-
self.assertEqual(add_node.op_type, "Add")
815+
shape_node_no_attr = nodes[0] # Shape without start attribute
816+
shape_node_start_0 = nodes[1] # Shape with start=0
817+
shape_node_start_1 = nodes[2] # Shape with start=1
815818

816-
# Try to match the pattern
817-
match_result = rule_pattern.match(model, model.graph, add_node)
819+
self.assertEqual(shape_node_no_attr.op_type, "Shape")
820+
self.assertEqual(shape_node_start_0.op_type, "Shape")
821+
self.assertEqual(shape_node_start_1.op_type, "Shape")
818822

823+
# Test case 1: Shape without start attribute (should match, default is 0)
824+
match_result = rule_pattern.match(model, model.graph, shape_node_no_attr)
819825
self.assertTrue(bool(match_result))
820-
self.assertEqual(len(match_result.nodes), 1)
821-
self.assertEqual(len(match_result.node_bindings), 1)
826+
827+
# Test case 2: Shape with start=0 (should match)
828+
match_result = rule_pattern.match(model, model.graph, shape_node_start_0)
829+
self.assertTrue(bool(match_result))
830+
831+
# Test case 3: Shape with start=1 (should not match)
832+
match_result = rule_pattern.match(model, model.graph, shape_node_start_1)
833+
self.assertFalse(bool(match_result))
822834

823835
def test_pattern_match_with_failing_node_checker(self):
824836
"""Test Pattern.match with failing node-level checker."""

0 commit comments

Comments
 (0)