@@ -368,6 +368,32 @@ def _make_loop(external_inputs, outputs):
         self.run_transpose_compare(["Y"], {"array": np.random.randn(10, 3, 4, 5).astype(np.float32)},
                                    model_proto, remaining_transpose_num=0)
 
+    def test_trans_with_sub(self):
+        io_shape = [2, 3, 4, 5]
+        const_shapes = [[2, 4, 5, 3], [4, 5, 3], [5, 3], [3]]
+        for trans_is_first_input in [True, False]:
+            for const_shape in const_shapes:
+                node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 2, 3, 1], name="trans_a")
+                const_tensor = helper.make_tensor(name='const', data_type=TensorProto.FLOAT, dims=const_shape,
+                                                  vals=np.random.randn(*const_shape).flatten().astype(np.float32))
+                node2 = helper.make_node("Constant", [], ["const"], value=const_tensor, name="const")
+                if trans_is_first_input:
+                    node3 = helper.make_node("Sub", ["Y", "const"], ["Z"], name="sub")
+                else:
+                    node3 = helper.make_node("Sub", ["const", "Y"], ["Z"], name="sub")
+
+                node4 = helper.make_node("Transpose", ["Z"], ["res"], perm=[0, 3, 1, 2], name="trans_b")
+                graph = helper.make_graph(
+                    [node1, node2, node3, node4],
+                    "test_trans_with_sub",
+                    [helper.make_tensor_value_info("X", TensorProto.FLOAT, io_shape)],
+                    [helper.make_tensor_value_info("res", TensorProto.FLOAT, io_shape)],
+                )
+
+                model_proto = helper.make_model(graph, producer_name="onnx-tests")
+                self.run_transpose_compare(["res"], {"X": np.random.randn(2, 3, 4, 5).astype(np.float32)},
+                                           model_proto, remaining_transpose_num=0)
+
     def test_trans_output_as_graph_outputs(self):
         """
         If transpose's output is graph's output, don't optimize it.