@@ -27,7 +27,7 @@ def run_and_compare(self, output_names_with_port, onnx_feed_dict, origin_proto,
27
27
28
28
origin_model_path = self .save_onnx_model (origin_proto , onnx_feed_dict , postfix = "_origin" )
29
29
30
- new_proto = GraphUtil .optimize_graph_with_model_proto (origin_proto )
30
+ new_proto = GraphUtil .optimize_model_proto (origin_proto )
31
31
32
32
self .assertTrue (new_proto , msg = "model proto after optimizer should not be None" )
33
33
@@ -287,7 +287,115 @@ def test_identity_in_subgraph_non_graph_output(self):
287
287
self .run_identity_compare (["Z1" ], {"X" : np .random .randn (2 , 3 , 4 , 5 ).astype (np .float32 )},
288
288
model_proto , remaining_identity_num = 0 )
289
289
290
- # Tranpose Optimizer Tests End
290
+ # Identity Optimizer Tests End
291
+
292
+ # Merge Duplicated Nodes Optimizer Tests Start
293
+
294
+ def run_merge_duplicated_nodes_compare (self , output_names_with_port , onnx_feed_dict , origin_proto ,
295
+ op_type = None , remaining_op_num = None , debug = False , rtol = 1e-07 ):
296
+ self .run_and_compare (output_names_with_port , onnx_feed_dict , origin_proto , op_type = op_type ,
297
+ remaining_op_num = remaining_op_num , debug = debug , rtol = rtol )
298
+
299
+ def test_duplicated_duplicated_input (self ):
300
+ # same input or not
301
+ node0 = helper .make_node ('Add' , inputs = ["X" , "X" ], outputs = ["value0" ])
302
+ node1 = helper .make_node ('Add' , inputs = ["X" , "X" ], outputs = ["value1" ])
303
+ node2 = helper .make_node ('Add' , inputs = ["value1" , "X" ], outputs = ["value2" ])
304
+ node3 = helper .make_node ("Mul" , ["value0" , "value2" ], ["value3" ])
305
+ node4 = helper .make_node ("Mul" , ["value1" , "value3" ], ["OUT" ])
306
+
307
+ graph = helper .make_graph (
308
+ [node0 , node1 , node2 , node3 , node4 ],
309
+ "test_duplicated_duplicated_input" ,
310
+ [helper .make_tensor_value_info ("X" , TensorProto .FLOAT , (5 , 5 ))],
311
+ [helper .make_tensor_value_info ("OUT" , TensorProto .FLOAT , (5 , 5 ))],
312
+ )
313
+
314
+ model_proto = helper .make_model (graph , producer_name = "onnx-tests" )
315
+ self .run_merge_duplicated_nodes_compare (["OUT" ], {"X" : np .random .randn (5 , 5 ).astype (np .float32 )}, model_proto ,
316
+ op_type = "Add" , remaining_op_num = 2 )
317
+
318
+ def test_duplicated_duplicated_attributes (self ):
319
+ # same attr or not
320
+ node0 = helper .make_node ('ReduceSum' , inputs = ["X" ], outputs = ["value0" ], axes = [0 ], keepdims = 0 )
321
+ node1 = helper .make_node ('ReduceSum' , inputs = ["X" ], outputs = ["value1" ], axes = [0 ], keepdims = 0 )
322
+ node2 = helper .make_node ('ReduceSum' , inputs = ["X" ], outputs = ["value2" ], axes = [1 ], keepdims = 0 )
323
+ node3 = helper .make_node ('Add' , inputs = ["value0" , "value1" ], outputs = ["value3" ])
324
+ node4 = helper .make_node ("Mul" , ["value2" , "value3" ], ["OUT" ])
325
+
326
+ graph = helper .make_graph (
327
+ [node0 , node1 , node2 , node3 , node4 ],
328
+ "test_duplicated_duplicated_attributes" ,
329
+ [helper .make_tensor_value_info ("X" , TensorProto .FLOAT , (5 , 5 ))],
330
+ [helper .make_tensor_value_info ("OUT" , TensorProto .FLOAT , (5 ,))],
331
+ )
332
+
333
+ model_proto = helper .make_model (graph , producer_name = "onnx-tests" )
334
+ self .run_merge_duplicated_nodes_compare (["OUT" ], {"X" : np .random .randn (5 , 5 ).astype (np .float32 )}, model_proto ,
335
+ op_type = "ReduceSum" , remaining_op_num = 2 )
336
+
337
+ def test_duplicated_node_is_graph_output (self ):
338
+ node0 = helper .make_node ('Add' , inputs = ["X" , "X" ], outputs = ["value0" ])
339
+ node1 = helper .make_node ('Add' , inputs = ["X" , "X" ], outputs = ["value1" ])
340
+ node2 = helper .make_node ('Add' , inputs = ["value1" , "X" ], outputs = ["value2" ])
341
+
342
+ graph = helper .make_graph (
343
+ [node0 , node1 , node2 ],
344
+ "test_duplicated_node_is_graph_output" ,
345
+ [helper .make_tensor_value_info ("X" , TensorProto .FLOAT , (5 , 5 ))],
346
+ [helper .make_tensor_value_info ("value1" , TensorProto .FLOAT , (5 , 5 )),
347
+ helper .make_tensor_value_info ("value2" , TensorProto .FLOAT , (5 , 5 ))],
348
+ )
349
+
350
+ model_proto = helper .make_model (graph , producer_name = "onnx-tests" )
351
+ self .run_merge_duplicated_nodes_compare (["value1" , "value2" ],
352
+ {"X" : np .random .randn (5 , 5 ).astype (np .float32 )}, model_proto ,
353
+ op_type = "Add" , remaining_op_num = 2 )
354
+
355
+ def test_duplicated_different_output_length (self ):
356
+ node0 = helper .make_node ('Dropout' , inputs = ["X" ], outputs = ["value0" ])
357
+ node1 = helper .make_node ('Dropout' , inputs = ["X" ], outputs = ["value1" , "mask" ])
358
+ node2 = helper .make_node ('Dropout' , inputs = ["value1" ], outputs = ["value2" ])
359
+
360
+ graph = helper .make_graph (
361
+ [node0 , node1 , node2 ],
362
+ "test_duplicated_different_output_length" ,
363
+ [helper .make_tensor_value_info ("X" , TensorProto .FLOAT , (5 ,))],
364
+ [helper .make_tensor_value_info ("value1" , TensorProto .FLOAT , (5 ,)),
365
+ helper .make_tensor_value_info ("mask" , TensorProto .BOOL , (5 ,)),
366
+ helper .make_tensor_value_info ("value2" , TensorProto .FLOAT , (5 ,))],
367
+ )
368
+
369
+ model_proto = helper .make_model (graph , producer_name = "onnx-tests" )
370
+ self .run_merge_duplicated_nodes_compare (["value1" , "mask" , "value2" ],
371
+ {"X" : np .random .randn (5 ,).astype (np .float32 )},
372
+ model_proto ,
373
+ op_type = "Dropout" , remaining_op_num = 2 )
374
+
375
+ def test_duplicated_need_multiple_run (self ):
376
+ node00 = helper .make_node ('Log' , inputs = ["X" ], outputs = ["value00" ])
377
+ node01 = helper .make_node ('Log' , inputs = ["value00" ], outputs = ["value01" ])
378
+ node02 = helper .make_node ('Log' , inputs = ["value01" ], outputs = ["value02" ])
379
+
380
+ node10 = helper .make_node ('Log' , inputs = ["X" ], outputs = ["value10" ])
381
+ node11 = helper .make_node ('Log' , inputs = ["value10" ], outputs = ["value11" ])
382
+ node12 = helper .make_node ('Log' , inputs = ["value11" ], outputs = ["value12" ])
383
+
384
+ res = helper .make_node ('Add' , inputs = ["value02" , "value12" ], outputs = ["res" ])
385
+
386
+ graph = helper .make_graph (
387
+ [node00 , node01 , node02 , node10 , node11 , node12 , res ],
388
+ "test_duplicated_node_is_graph_output" ,
389
+ [helper .make_tensor_value_info ("X" , TensorProto .FLOAT , (5 ,))],
390
+ [helper .make_tensor_value_info ("res" , TensorProto .FLOAT , (5 ,))],
391
+ )
392
+
393
+ model_proto = helper .make_model (graph , producer_name = "onnx-tests" )
394
+ self .run_merge_duplicated_nodes_compare (["res" ], {"X" : np .random .randn (5 ,).astype (np .float32 )},
395
+ model_proto ,
396
+ op_type = "Log" , remaining_op_num = 3 )
397
+ # Merge Duplicated Nodes Optimizer Tests End
398
+
291
399
292
400
if __name__ == "__main__" :
293
401
unittest_main ()
0 commit comments