@@ -306,7 +306,7 @@ def test_duplicated_duplicated_input(self):
306
306
307
307
graph = helper .make_graph (
308
308
[node0 , node1 , node2 , node3 , node4 ],
309
- "transpose-merge-test " ,
309
+ "test_duplicated_duplicated_input " ,
310
310
[helper .make_tensor_value_info ("X" , TensorProto .FLOAT , (5 , 5 ))],
311
311
[helper .make_tensor_value_info ("OUT" , TensorProto .FLOAT , (5 , 5 ))],
312
312
)
@@ -325,14 +325,74 @@ def test_duplicated_duplicated_attributes(self):
325
325
326
326
graph = helper .make_graph (
327
327
[node0 , node1 , node2 , node3 , node4 ],
328
- "transpose-merge-test " ,
328
+ "test_duplicated_duplicated_attributes " ,
329
329
[helper .make_tensor_value_info ("X" , TensorProto .FLOAT , (5 , 5 ))],
330
330
[helper .make_tensor_value_info ("OUT" , TensorProto .FLOAT , (5 ,))],
331
331
)
332
332
333
333
model_proto = helper .make_model (graph , producer_name = "onnx-tests" )
334
334
self .run_merge_duplicated_nodes_compare (["OUT" ], {"X" : np .random .randn (5 , 5 ).astype (np .float32 )}, model_proto ,
335
335
op_type = "ReduceSum" , remaining_op_num = 2 )
336
+
337
+ def test_duplicated_node_is_graph_output (self ):
338
+ node0 = helper .make_node ('Add' , inputs = ["X" , "X" ], outputs = ["value0" ])
339
+ node1 = helper .make_node ('Add' , inputs = ["X" , "X" ], outputs = ["value1" ])
340
+ node2 = helper .make_node ('Add' , inputs = ["value1" , "X" ], outputs = ["value2" ])
341
+
342
+ graph = helper .make_graph (
343
+ [node0 , node1 , node2 ],
344
+ "test_duplicated_node_is_graph_output" ,
345
+ [helper .make_tensor_value_info ("X" , TensorProto .FLOAT , (5 , 5 ))],
346
+ [helper .make_tensor_value_info ("value1" , TensorProto .FLOAT , (5 , 5 )),
347
+ helper .make_tensor_value_info ("value2" , TensorProto .FLOAT , (5 , 5 ))],
348
+ )
349
+
350
+ model_proto = helper .make_model (graph , producer_name = "onnx-tests" )
351
+ self .run_merge_duplicated_nodes_compare (["value1" , "value2" ], {"X" : np .random .randn (5 , 5 ).astype (np .float32 )}, model_proto ,
352
+ op_type = "Add" , remaining_op_num = 2 )
353
+
354
+ def test_duplicated_different_output_length (self ):
355
+ node0 = helper .make_node ('Dropout' , inputs = ["X" ], outputs = ["value0" ])
356
+ node1 = helper .make_node ('Dropout' , inputs = ["X" ], outputs = ["value1" , "mask" ])
357
+ node2 = helper .make_node ('Dropout' , inputs = ["value1" ], outputs = ["value2" ])
358
+
359
+ graph = helper .make_graph (
360
+ [node0 , node1 , node2 ],
361
+ "test_duplicated_different_output_length" ,
362
+ [helper .make_tensor_value_info ("X" , TensorProto .FLOAT , (5 ,))],
363
+ [helper .make_tensor_value_info ("value1" , TensorProto .FLOAT , (5 ,)),
364
+ helper .make_tensor_value_info ("mask" , TensorProto .BOOL , (5 ,)),
365
+ helper .make_tensor_value_info ("value2" , TensorProto .FLOAT , (5 ,))],
366
+ )
367
+
368
+ model_proto = helper .make_model (graph , producer_name = "onnx-tests" )
369
+ self .run_merge_duplicated_nodes_compare (["value1" , "mask" , "value2" ],
370
+ {"X" : np .random .randn (5 ,).astype (np .float32 )},
371
+ model_proto ,
372
+ op_type = "Dropout" , remaining_op_num = 2 )
373
+
374
+ def test_duplicated_need_multiple_run (self ):
375
+ node00 = helper .make_node ('Log' , inputs = ["X" ], outputs = ["value00" ])
376
+ node01 = helper .make_node ('Log' , inputs = ["value00" ], outputs = ["value01" ])
377
+ node02 = helper .make_node ('Log' , inputs = ["value01" ], outputs = ["value02" ])
378
+
379
+ node10 = helper .make_node ('Log' , inputs = ["X" ], outputs = ["value10" ])
380
+ node11 = helper .make_node ('Log' , inputs = ["value10" ], outputs = ["value11" ])
381
+ node12 = helper .make_node ('Log' , inputs = ["value11" ], outputs = ["value12" ])
382
+
383
+ res = helper .make_node ('Add' , inputs = ["value02" , "value12" ], outputs = ["res" ])
384
+
385
+ graph = helper .make_graph (
386
+ [node00 , node01 , node02 , node10 , node11 , node12 , res ],
387
+ "test_duplicated_node_is_graph_output" ,
388
+ [helper .make_tensor_value_info ("X" , TensorProto .FLOAT , (5 ,))],
389
+ [helper .make_tensor_value_info ("res" , TensorProto .FLOAT , (5 ,))],
390
+ )
391
+
392
+ model_proto = helper .make_model (graph , producer_name = "onnx-tests" )
393
+ self .run_merge_duplicated_nodes_compare (["res" ], {"X" : np .random .randn (5 ,).astype (np .float32 )},
394
+ model_proto ,
395
+ op_type = "Log" , remaining_op_num = 3 )
336
396
# Merge Duplicated Nodes Optimizer Tests End
337
397
338
398
0 commit comments