@@ -989,5 +989,247 @@ def test_pattern_builder_context(self):
989989 self .assertEqual (ops , ["Op1" , "Op2" , "Add" , "Op3" , "Mul" ])
990990
991991
class FlexibleOutputTest(unittest.TestCase):
    """Test patterns with flexible output counts using ``_allow_flexible_outputs``.

    These tests exercise the rewriter's ability to match ops such as Split
    whose output arity is not fixed in the pattern, and verify both the
    flexible path and the strict output-count validation path.
    """

    @staticmethod
    def _count_ops(proto, op_type):
        """Return the number of nodes of *op_type* in the model's main graph."""
        return sum(1 for node in proto.graph.node if node.op_type == op_type)

    def test_flexible_outputs_with_split(self):
        """Test that _allow_flexible_outputs works for Split with varying output counts."""

        def relu_split_pattern(op, x):
            relu = op.Relu(x)
            return op.Split(relu, _allow_flexible_outputs=True)

        def relu_split_rewrite(op, _match=None, x=None):
            if x is None or _match is None:
                return None

            split = next((n for n in _match.nodes if n.op_type == "Split"), None)
            if not split:
                return None

            num_outputs = len(split.outputs)
            split_results = op.Split(x, _outputs=num_outputs, **split.attributes)

            # Wrap each split result in a Relu; a single output is returned bare.
            return (
                tuple(op.Relu(s) for s in split_results)
                if num_outputs > 1
                else op.Relu(split_results)
            )

        rule = pattern.RewriteRule(relu_split_pattern, relu_split_rewrite)

        # Test model with Relu -> Split pattern
        model_proto = onnx.parser.parse_model(
            """
            <ir_version: 7, opset_import: [ "" : 18]>
            agraph (float[10] x) => (float[5] out1, float[5] out2)
            {
                relu_out = Relu(x)
                out1, out2 = Split<axis=0, num_outputs=2>(relu_out)
            }
            """
        )

        optimized = onnxscript.rewriter.rewrite(model_proto, pattern_rewrite_rules=[rule])

        # Verify transformation: 1 Relu + 1 Split -> 1 Split + 2 Relu
        self.assertEqual(self._count_ops(model_proto, "Relu"), 1)
        self.assertEqual(self._count_ops(model_proto, "Split"), 1)
        self.assertEqual(self._count_ops(optimized, "Relu"), 2)
        self.assertEqual(self._count_ops(optimized, "Split"), 1)

    def test_flexible_outputs_without_match_parameter(self):
        """A flexible-output pattern whose rewrite does not take the _match parameter."""

        def pattern_func(op, x):
            return op.Split(x, _allow_flexible_outputs=True)

        def rewrite_func(op, x=None):
            if x is None:
                return None
            return op.Split(op.Relu(x), _outputs=2)

        rule = pattern.RewriteRule(pattern_func, rewrite_func)

        model_proto = onnx.parser.parse_model(
            """
            <ir_version: 7, opset_import: [ "" : 18]>
            agraph (float[10] x) => (float[5] out1, float[5] out2)
            {
                out1, out2 = Split<axis=0, num_outputs=2>(x)
            }
            """
        )

        optimized = onnxscript.rewriter.rewrite(model_proto, pattern_rewrite_rules=[rule])

        self.assertEqual(self._count_ops(optimized, "Relu"), 1)
        self.assertEqual(self._count_ops(optimized, "Split"), 1)

    def test_output_count_validation_without_flexible(self):
        """A replacement with a mismatched output count must raise ValueError."""

        def pattern_func(op, x):
            return op.Relu(x)

        def bad_rewrite_func(op, x=None):
            # Pattern produces 1 output, replacement produces 2 -> invalid.
            return op.Split(x, _outputs=2)

        rule = pattern.RewriteRule(pattern_func, bad_rewrite_func)

        model_proto = onnx.parser.parse_model(
            """
            <ir_version: 7, opset_import: [ "" : 18]>
            agraph (float[10] x) => (float[10] out)
            {
                out = Relu(x)
            }
            """
        )

        with self.assertRaises(ValueError) as ctx:
            onnxscript.rewriter.rewrite(model_proto, pattern_rewrite_rules=[rule])

        self.assertIn("Number of outputs", str(ctx.exception))

    def test_standard_replacement_path(self):
        """A plain fixed-output rewrite (Relu -> Sigmoid) still works."""

        def pattern_func(op, x):
            return op.Relu(x)

        def rewrite_func(op, x=None):
            return op.Sigmoid(x)

        rule = pattern.RewriteRule(pattern_func, rewrite_func)

        model_proto = onnx.parser.parse_model(
            """
            <ir_version: 7, opset_import: [ "" : 18]>
            agraph (float[10] x) => (float[10] out)
            {
                out = Relu(x)
            }
            """
        )

        optimized = onnxscript.rewriter.rewrite(model_proto, pattern_rewrite_rules=[rule])

        self.assertEqual(self._count_ops(optimized, "Relu"), 0)
        self.assertEqual(self._count_ops(optimized, "Sigmoid"), 1)

    def test_flexible_outputs_with_three_outputs(self):
        """The flexible pattern matches a Split with three outputs."""

        def pattern_func(op, x):
            return op.Split(x, _allow_flexible_outputs=True)

        def rewrite_func(op, _match=None, x=None):
            if x is None or _match is None:
                return None

            split = next((n for n in _match.nodes if n.op_type == "Split"), None)
            if not split:
                return None

            num_outputs = len(split.outputs)
            relu = op.Relu(x)
            return op.Split(relu, _outputs=num_outputs, **split.attributes)

        rule = pattern.RewriteRule(pattern_func, rewrite_func)

        model_proto = onnx.parser.parse_model(
            """
            <ir_version: 7, opset_import: [ "" : 18]>
            agraph (float[15] x) => (float[5] out1, float[5] out2, float[5] out3)
            {
                out1, out2, out3 = Split<axis=0, num_outputs=3>(x)
            }
            """
        )

        optimized = onnxscript.rewriter.rewrite(model_proto, pattern_rewrite_rules=[rule])

        self.assertEqual(self._count_ops(optimized, "Relu"), 1)
        self.assertEqual(self._count_ops(optimized, "Split"), 1)

    def test_flexible_outputs_with_single_output(self):
        """The flexible pattern matches a Split with a single output."""

        def pattern_func(op, x):
            return op.Split(x, _allow_flexible_outputs=True)

        def rewrite_func(op, _match=None, x=None):
            if x is None or _match is None:
                return None

            split = next((n for n in _match.nodes if n.op_type == "Split"), None)
            if not split:
                return None

            relu = op.Relu(x)
            if len(split.outputs) == 1:
                return op.Split(relu, _outputs=1, **split.attributes)
            return relu

        rule = pattern.RewriteRule(pattern_func, rewrite_func)

        model_proto = onnx.parser.parse_model(
            """
            <ir_version: 7, opset_import: [ "" : 18]>
            agraph (float[10] x) => (float[10] out)
            {
                out = Split<axis=0>(x)
            }
            """
        )

        optimized = onnxscript.rewriter.rewrite(model_proto, pattern_rewrite_rules=[rule])

        self.assertEqual(self._count_ops(optimized, "Relu"), 1)
        self.assertEqual(self._count_ops(optimized, "Split"), 1)

    def test_flexible_outputs_with_partial_usage(self):
        """Rewrite applies when the Split outputs feed further downstream nodes."""

        def pattern_func(op, x):
            return op.Split(x, _allow_flexible_outputs=True)

        def rewrite_func(op, _match=None, x=None):
            if x is None or _match is None:
                return None

            split = next((n for n in _match.nodes if n.op_type == "Split"), None)
            if not split:
                return None

            num_outputs = len(split.outputs)
            relu = op.Relu(x)
            return op.Split(relu, _outputs=num_outputs, **split.attributes)

        rule = pattern.RewriteRule(pattern_func, rewrite_func)

        model_proto = onnx.parser.parse_model(
            """
            <ir_version: 7, opset_import: [ "" : 18]>
            agraph (float[10] x) => (float[5] out1, float[5] out2, float[5] sum)
            {
                s1, s2 = Split<axis=0, num_outputs=2>(x)
                out1 = Abs(s1)
                out2 = Neg(s2)
                sum = Add(out1, out2)
            }
            """
        )

        optimized = onnxscript.rewriter.rewrite(model_proto, pattern_rewrite_rules=[rule])

        self.assertEqual(self._count_ops(optimized, "Relu"), 1)
        self.assertEqual(self._count_ops(optimized, "Split"), 1)
1233+
# Script entry point: run the tests in this module directly.
if __name__ == "__main__":
    unittest.main()
# (removed web-scrape residue that was not valid Python)