@@ -27,7 +27,7 @@ def run_and_compare(self, output_names_with_port, onnx_feed_dict, origin_proto,
27
27
28
28
origin_model_path = self .save_onnx_model (origin_proto , onnx_feed_dict , postfix = "_origin" )
29
29
30
- new_proto = GraphUtil .optimize_graph_with_model_proto (origin_proto )
30
+ new_proto = GraphUtil .optimize_model_proto (origin_proto )
31
31
32
32
self .assertTrue (new_proto , msg = "model proto after optimizer should not be None" )
33
33
@@ -287,7 +287,54 @@ def test_identity_in_subgraph_non_graph_output(self):
287
287
self .run_identity_compare (["Z1" ], {"X" : np .random .randn (2 , 3 , 4 , 5 ).astype (np .float32 )},
288
288
model_proto , remaining_identity_num = 0 )
289
289
290
- # Tranpose Optimizer Tests End
290
+ # Identity Optimizer Tests End
291
+
292
+ # Merge Duplicated Nodes Optimizer Tests Start
293
+
294
+ def run_merge_duplicated_nodes_compare (self , output_names_with_port , onnx_feed_dict , origin_proto ,
295
+ op_type = None , remaining_op_num = None , debug = False , rtol = 1e-07 ):
296
+ self .run_and_compare (output_names_with_port , onnx_feed_dict , origin_proto , op_type = op_type ,
297
+ remaining_op_num = remaining_op_num , debug = debug , rtol = rtol )
298
+
299
+ def test_duplicated_duplicated_input (self ):
300
+ # same input or not
301
+ node0 = helper .make_node ('Add' , inputs = ["X" , "X" ], outputs = ["value0" ])
302
+ node1 = helper .make_node ('Add' , inputs = ["X" , "X" ], outputs = ["value1" ])
303
+ node2 = helper .make_node ('Add' , inputs = ["value1" , "X" ], outputs = ["value2" ])
304
+ node3 = helper .make_node ("Mul" , ["value0" , "value2" ], ["value3" ])
305
+ node4 = helper .make_node ("Mul" , ["value1" , "value3" ], ["OUT" ])
306
+
307
+ graph = helper .make_graph (
308
+ [node0 , node1 , node2 , node3 , node4 ],
309
+ "transpose-merge-test" ,
310
+ [helper .make_tensor_value_info ("X" , TensorProto .FLOAT , (5 , 5 ))],
311
+ [helper .make_tensor_value_info ("OUT" , TensorProto .FLOAT , (5 , 5 ))],
312
+ )
313
+
314
+ model_proto = helper .make_model (graph , producer_name = "onnx-tests" )
315
+ self .run_merge_duplicated_nodes_compare (["OUT" ], {"X" : np .random .randn (5 , 5 ).astype (np .float32 )}, model_proto ,
316
+ op_type = "Add" , remaining_op_num = 2 )
317
+
318
+ def test_duplicated_duplicated_attributes (self ):
319
+ # same attr or not
320
+ node0 = helper .make_node ('ReduceSum' , inputs = ["X" ], outputs = ["value0" ], axes = [0 ], keepdims = 0 )
321
+ node1 = helper .make_node ('ReduceSum' , inputs = ["X" ], outputs = ["value1" ], axes = [0 ], keepdims = 0 )
322
+ node2 = helper .make_node ('ReduceSum' , inputs = ["X" ], outputs = ["value2" ], axes = [1 ], keepdims = 0 )
323
+ node3 = helper .make_node ('Add' , inputs = ["value0" , "value1" ], outputs = ["value3" ])
324
+ node4 = helper .make_node ("Mul" , ["value2" , "value3" ], ["OUT" ])
325
+
326
+ graph = helper .make_graph (
327
+ [node0 , node1 , node2 , node3 , node4 ],
328
+ "transpose-merge-test" ,
329
+ [helper .make_tensor_value_info ("X" , TensorProto .FLOAT , (5 , 5 ))],
330
+ [helper .make_tensor_value_info ("OUT" , TensorProto .FLOAT , (5 ,))],
331
+ )
332
+
333
+ model_proto = helper .make_model (graph , producer_name = "onnx-tests" )
334
+ self .run_merge_duplicated_nodes_compare (["OUT" ], {"X" : np .random .randn (5 , 5 ).astype (np .float32 )}, model_proto ,
335
+ op_type = "ReduceSum" , remaining_op_num = 2 )
336
+ # Merge Duplicated Nodes Optimizer Tests End
337
+
291
338
292
339
if __name__ == "__main__" :
293
340
unittest_main ()
0 commit comments