@@ -786,39 +786,51 @@ class ValueNodeCheckersTest(unittest.TestCase):
786786 def test_pattern_match_with_node_checker (self ):
787787 """Test Pattern.match with node-level checker."""
788788
789- def add_node_checker (context , node ):
790- return node .op_type == "Add"
789+ def shape_node_checker (context , node ):
790+ return node .attributes . get_int ( "start" , 0 ) == 0
791791
792- # Create a pattern that matches Add operations with a node checker
793- def add_pattern (op , x , y ):
794- return op .Add (x , y , _check = add_node_checker )
792+ # Create a pattern that matches Shape operations with a node checker
793+ def shape_pattern (op , x ):
794+ return op .Shape (x , _check = shape_node_checker )
795795
796796 # Create the pattern
797- rule_pattern = pattern .Pattern (add_pattern )
797+ rule_pattern = pattern .Pattern (shape_pattern )
798798
799- # Create a simple model
799+ # Create a model with multiple Shape nodes with different start attributes
800800 model_proto = onnx .parser .parse_model (
801801 """
802802 <ir_version: 7, opset_import: [ "" : 17]>
803- agraph (float[N] x, float[N] y ) => (float[N] z )
803+ agraph (float[N, M] x ) => (int64[2] z1, int64[2] z2, int64[1] z3 )
804804 {
805- z = Add(x, y)
805+ z1 = Shape(x)
806+ z2 = Shape <start: int = 0>(x)
807+ z3 = Shape <start: int = 1>(x)
806808 }
807809 """
808810 )
809811 model = ir .serde .deserialize_model (model_proto )
810812
811- # Find the Add node in the model
813+ # Find the Shape nodes in the model
812814 nodes = list (model .graph )
813- add_node = nodes [0 ]
814- self .assertEqual (add_node .op_type , "Add" )
815+ shape_node_no_attr = nodes [0 ] # Shape without start attribute
816+ shape_node_start_0 = nodes [1 ] # Shape with start=0
817+ shape_node_start_1 = nodes [2 ] # Shape with start=1
815818
816- # Try to match the pattern
817- match_result = rule_pattern .match (model , model .graph , add_node )
819+ self .assertEqual (shape_node_no_attr .op_type , "Shape" )
820+ self .assertEqual (shape_node_start_0 .op_type , "Shape" )
821+ self .assertEqual (shape_node_start_1 .op_type , "Shape" )
818822
823+ # Test case 1: Shape without start attribute (should match, default is 0)
824+ match_result = rule_pattern .match (model , model .graph , shape_node_no_attr )
819825 self .assertTrue (bool (match_result ))
820- self .assertEqual (len (match_result .nodes ), 1 )
821- self .assertEqual (len (match_result .node_bindings ), 1 )
826+
827+ # Test case 2: Shape with start=0 (should match)
828+ match_result = rule_pattern .match (model , model .graph , shape_node_start_0 )
829+ self .assertTrue (bool (match_result ))
830+
831+ # Test case 3: Shape with start=1 (should not match)
832+ match_result = rule_pattern .match (model , model .graph , shape_node_start_1 )
833+ self .assertFalse (bool (match_result ))
822834
823835 def test_pattern_match_with_failing_node_checker (self ):
824836 """Test Pattern.match with failing node-level checker."""
0 commit comments