@@ -445,6 +445,91 @@ def test_trans_with_sub_input_non_const(self):
445
445
"non_const" : np .random .randn (* non_const_shape ).astype (np .float32 )},
446
446
model_proto , remaining_transpose_num = 1 )
447
447
448
+ def test_transpose_add_with_input_non_const (self ):
449
+
450
+ node0 = helper .make_node ("Transpose" , ["X" ], ["Y" ], perm = [0 , 2 , 3 , 1 ], name = "trans_1" )
451
+ node1 = helper .make_node ("Add" , ["Y" , "A" ], ["Z" ], name = "add" )
452
+ node2 = helper .make_node ("Transpose" , ["Z" ], ["res" ], perm = [0 , 3 , 1 , 2 ], name = "trans_2" )
453
+
454
+ graph = helper .make_graph (
455
+ [node0 , node1 , node2 ],
456
+ "transpose-add-test-input-non-const" ,
457
+ [helper .make_tensor_value_info ("X" , TensorProto .FLOAT , (1 , 1 , 3 , 3 )),
458
+ helper .make_tensor_value_info ("A" , TensorProto .FLOAT , (1 , 3 , 3 , 1 ))],
459
+ [helper .make_tensor_value_info ("res" , TensorProto .FLOAT , (1 , 1 , 3 , 3 ))],
460
+ )
461
+
462
+ model_proto = helper .make_model (graph , producer_name = "onnx-tests" )
463
+ self .run_transpose_compare (["res" ], {"X" : np .random .randn (1 , 1 , 3 , 3 ).astype (np .float32 ),
464
+ "A" : np .random .randn (1 , 3 , 3 , 1 ).astype (np .float32 )},
465
+ model_proto , remaining_transpose_num = 0 )
466
+
467
+ def test_transpose_add_with_input_const (self ):
468
+ const_1_val = np .random .randn (1 , 3 , 3 , 1 ).astype (np .float32 ).reshape (9 ).tolist ()
469
+ const_1 = helper .make_tensor ("const_1" , TensorProto .FLOAT , (1 , 3 , 3 , 1 ), const_1_val )
470
+ const_1_node = helper .make_node ("Constant" , [], ["const_1" ], value = const_1 , name = "const_1" )
471
+
472
+ node0 = helper .make_node ("Transpose" , ["X" ], ["Y" ], perm = [0 , 2 , 3 , 1 ], name = "trans_1" )
473
+ node1 = helper .make_node ("Add" , ["Y" , "const_1" ], ["Z" ], name = "add" )
474
+ node2 = helper .make_node ("Transpose" , ["Z" ], ["res" ], perm = [0 , 3 , 1 , 2 ], name = "trans_2" )
475
+
476
+ graph = helper .make_graph (
477
+ [const_1_node , node0 , node1 , node2 ],
478
+ "transpose-add-test-input-const" ,
479
+ [helper .make_tensor_value_info ("X" , TensorProto .FLOAT , (1 , 1 , 3 , 3 ))],
480
+ [helper .make_tensor_value_info ("res" , TensorProto .FLOAT , (1 , 1 , 3 , 3 ))],
481
+ )
482
+
483
+ model_proto = helper .make_model (graph , producer_name = "onnx-tests" )
484
+ self .run_transpose_compare (["res" ], {"X" : np .random .randn (1 , 1 , 3 , 3 ).astype (np .float32 )},
485
+ model_proto , remaining_transpose_num = 0 )
486
+
487
+ def test_transpose_add_with_conv_1 (self ):
488
+ const_b_val = np .random .randn (1 , 1 , 1 , 16 ).astype (np .float32 ).reshape (16 ).tolist ()
489
+ const_b = helper .make_tensor ("const_b" , TensorProto .FLOAT , (1 , 1 , 1 , 16 ), const_b_val )
490
+ const_b_node = helper .make_node ("Constant" , [], ["const_b" ], value = const_b , name = "const_b" )
491
+
492
+ node0 = helper .make_node ("Conv" , ["x" , "W" ], ["X" ], name = "conv" , pads = [0 , 0 , 0 , 0 ])
493
+ node1 = helper .make_node ("Transpose" , ["X" ], ["Y" ], perm = [0 , 2 , 3 , 1 ], name = "trans_1" )
494
+ node2 = helper .make_node ("Add" , ["Y" , "const_b" ], ["Z" ], name = "add" )
495
+ node3 = helper .make_node ("Transpose" , ["Z" ], ["res" ], perm = [0 , 3 , 1 , 2 ], name = "trans_2" )
496
+
497
+ graph = helper .make_graph (
498
+ [ const_b_node , node0 , node1 , node2 , node3 ],
499
+ "transpose-add-test-with-conv-1" ,
500
+ [helper .make_tensor_value_info ("x" , TensorProto .FLOAT , (1 , 5 , 3 , 3 )),
501
+ helper .make_tensor_value_info ("W" , TensorProto .FLOAT , (16 , 5 , 3 , 3 ))],
502
+ [helper .make_tensor_value_info ("res" , TensorProto .FLOAT , (1 , 16 , 1 , 1 ))],
503
+ )
504
+
505
+ model_proto = helper .make_model (graph , producer_name = "onnx-tests" )
506
+ self .run_transpose_compare (["res" ], {"x" : np .random .randn (1 , 5 , 3 , 3 ).astype (np .float32 ),
507
+ "W" : np .random .randn (16 , 5 , 3 , 3 ).astype (np .float32 )},
508
+ model_proto , remaining_transpose_num = 0 )
509
+
510
+ def test_transpose_add_with_conv_2 (self ):
511
+ const_b_val = np .random .randn (1 , 3 , 3 , 1 ).astype (np .float32 ).reshape (9 ).tolist ()
512
+ const_b = helper .make_tensor ("const_b" , TensorProto .FLOAT , (1 , 3 , 3 , 1 ), const_b_val )
513
+ const_b_node = helper .make_node ("Constant" , [], ["const_b" ], value = const_b , name = "const_b" )
514
+
515
+ node0 = helper .make_node ("Conv" , ["x" , "W" ], ["X" ], name = "conv" , pads = [0 , 0 , 0 , 0 ])
516
+ node1 = helper .make_node ("Transpose" , ["X" ], ["Y" ], perm = [0 , 2 , 3 , 1 ], name = "trans_1" )
517
+ node2 = helper .make_node ("Add" , ["Y" , "const_b" ], ["Z" ], name = "add" )
518
+ node3 = helper .make_node ("Transpose" , ["Z" ], ["res" ], perm = [0 , 3 , 1 , 2 ], name = "trans_2" )
519
+
520
+ graph = helper .make_graph (
521
+ [const_b_node , node0 , node1 , node2 , node3 ],
522
+ "transpose-add-test-with-conv-2" ,
523
+ [helper .make_tensor_value_info ("x" , TensorProto .FLOAT , (1 , 1 , 5 , 5 )),
524
+ helper .make_tensor_value_info ("W" , TensorProto .FLOAT , (1 , 1 , 3 , 3 ))],
525
+ [helper .make_tensor_value_info ("res" , TensorProto .FLOAT , (1 , 1 , 3 , 3 ))],
526
+ )
527
+
528
+ model_proto = helper .make_model (graph , producer_name = "onnx-tests" )
529
+ self .run_transpose_compare (["res" ], {"x" : np .random .randn (1 , 1 , 5 , 5 ).astype (np .float32 ),
530
+ "W" : np .random .randn (1 , 1 , 3 , 3 ).astype (np .float32 )},
531
+ model_proto , remaining_transpose_num = 0 )
532
+
448
533
def test_trans_output_as_graph_outputs (self ):
449
534
"""
450
535
If transpose's output is graph's output, don't optimize it.
0 commit comments