@@ -27,8 +27,10 @@ def run_and_compare(self, output_names_with_port, onnx_feed_dict, origin_proto,
27
27
remaining_op_num , debug = False , rtol = 1e-07 ):
28
28
utils .make_sure (op_type is not None , "op_type should be specified" )
29
29
utils .make_sure (remaining_op_num is not None , "remaining_op_num should be specified" )
30
+ utils .make_sure (self .config .is_onnxruntime_backend , "only onnxruntime is supported to test transpose optimizer" )
30
31
31
32
origin_model_path = self .save_onnx_model (origin_proto , onnx_feed_dict , postfix = "_origin" )
33
+ expected = self .run_onnxruntime (origin_model_path , onnx_feed_dict , output_names_with_port )
32
34
33
35
new_proto , new_graph = GraphUtil .optimize_model_proto (origin_proto , catch_errors = False , return_graph = True )
34
36
@@ -37,21 +39,16 @@ def run_and_compare(self, output_names_with_port, onnx_feed_dict, origin_proto,
37
39
new_model_path = self .save_onnx_model (new_proto , onnx_feed_dict , postfix = "_opt" )
38
40
current = GraphUtil .get_node_count_from_onnx_graph (new_proto .graph )
39
41
40
- self .assertTrue (current [op_type ] == remaining_op_num ,
41
- msg = "Expect " + str (remaining_op_num ) + " " + op_type + " ops left, but actually " + str (
42
- current [op_type ]) + " left" )
43
-
44
- if self .config .is_onnxruntime_backend :
45
- expected = self .run_onnxruntime (origin_model_path , onnx_feed_dict , output_names_with_port )
46
- actual = self .run_onnxruntime (new_model_path , onnx_feed_dict , output_names_with_port )
47
- else :
48
- raise ValueError ("only onnxruntime is supported to test transpose optimizer" )
42
+ actual = self .run_onnxruntime (new_model_path , onnx_feed_dict , output_names_with_port )
49
43
50
44
for expected_val , actual_val in zip (expected , actual ):
51
45
self .assertAllClose (expected_val , actual_val , rtol = rtol , atol = 1e-5 )
52
46
self .assertEqual (expected_val .dtype , actual_val .dtype )
53
47
self .assertEqual (expected_val .shape , actual_val .shape )
54
48
49
+ self .assertTrue (current [op_type ] == remaining_op_num ,
50
+ msg = "Expect " + str (remaining_op_num ) + " " + op_type + " ops left, but actually " + str (
51
+ current [op_type ]) + " left" )
55
52
self .assert_shapes_correct (new_graph , allow_missing = False , run_checker = True )
56
53
57
54
return new_proto
@@ -124,6 +121,7 @@ def test_transpose_with_concat(self, input_shape, perm, inner_perm):
124
121
self .run_transpose_compare (["res" ], feed_dict , model_proto , remaining_transpose_num = 1 )
125
122
126
123
@parameterized .expand ([
124
+ ((2 , 3 , 4 ), [2 , 0 , 1 ], [1 , 2 , 0 ]),
127
125
((2 , 3 , 4 , 5 ), [0 , 2 , 3 , 1 ], [0 , 3 , 1 , 2 ]),
128
126
((2 , 3 , 4 , 5 , 6 ), [0 , 2 , 3 , 4 , 1 ], [0 , 4 , 1 , 2 , 3 ]),
129
127
])
@@ -176,6 +174,7 @@ def test_transpose_with_add2(self, input_shape1, input_shape2, perm_input, perm_
176
174
self .run_transpose_compare (["res" ], feed_dict , model_proto , remaining_transpose_num = 1 )
177
175
178
176
@parameterized .expand ([
177
+ ((2 , 3 , 4 ), [2 , 0 , 1 ], [1 , 2 , 0 ]),
179
178
((2 , 3 , 4 , 5 ), [0 , 2 , 3 , 1 ], [0 , 3 , 1 , 2 ]),
180
179
((2 , 3 , 4 , 5 , 6 ), [0 , 2 , 3 , 4 , 1 ], [0 , 4 , 1 , 2 , 3 ]),
181
180
])
@@ -196,6 +195,7 @@ def test_transpose_relu(self, shape, perm_input, perm_output):
196
195
model_proto , remaining_transpose_num = 0 )
197
196
198
197
@parameterized .expand ([
198
+ ((2 , 3 , 4 ), [2 , 0 , 1 ], [1 , 2 , 0 ]),
199
199
((2 , 3 , 4 , 5 ), [0 , 2 , 3 , 1 ], [0 , 3 , 1 , 2 ]),
200
200
((2 , 3 , 4 , 5 , 6 ), [0 , 2 , 3 , 4 , 1 ], [0 , 4 , 1 , 2 , 3 ]),
201
201
])
@@ -216,6 +216,7 @@ def test_transpose_leaky_relu(self, shape, perm_input, perm_output):
216
216
model_proto , remaining_transpose_num = 0 )
217
217
218
218
@parameterized .expand ([
219
+ ((2 , 3 , 4 ), [2 , 0 , 1 ], [1 , 2 , 0 ]),
219
220
((2 , 3 , 4 , 5 ), [0 , 2 , 3 , 1 ], [0 , 3 , 1 , 2 ]),
220
221
((2 , 3 , 4 , 5 , 6 ), [0 , 2 , 3 , 4 , 1 ], [0 , 4 , 1 , 2 , 3 ]),
221
222
])
@@ -240,15 +241,16 @@ def test_transpose_quantize(self, shape, perm_input, perm_output):
240
241
model_proto , remaining_transpose_num = 0 )
241
242
242
243
@parameterized .expand ([
244
+ ((2 , 3 , 4 ), [0 , 2 , 1 ], [0 , 2 , 1 ]),
243
245
((2 , 3 , 4 , 5 ), [0 , 2 , 3 , 1 ], [0 , 3 , 1 , 2 ]),
244
246
((2 , 3 , 4 , 5 , 6 ), [0 , 2 , 3 , 4 , 1 ], [0 , 4 , 1 , 2 , 3 ]),
245
247
])
246
248
@check_opset_min_version (13 , "QuantizeLinear with axis" )
247
249
def test_transpose_quantize_with_axis (self , shape , perm_input , perm_output ):
248
- scale = numpy_helper .from_array (np .array ([0.75 , 0.1 , 2.3 , 0.3 , 0.42 ], dtype = np .float32 ), name = 'scale' )
249
- zero_point = numpy_helper .from_array (np .array ([2 , 4 , 6 , 8 , 10 ], dtype = np .uint8 ), name = 'zero_point' )
250
+ scale = numpy_helper .from_array (np .array ([0.75 , 0.1 , 2.3 , 0.3 ], dtype = np .float32 ), name = 'scale' )
251
+ zero_point = numpy_helper .from_array (np .array ([2 , 4 , 6 , 8 ], dtype = np .uint8 ), name = 'zero_point' )
250
252
node1 = helper .make_node ("Transpose" , ["X" ], ["Y" ], perm = perm_input , name = "trans_1" )
251
- node2 = helper .make_node ("QuantizeLinear" , ["Y" , "scale" , "zero_point" ], ["Z" ], name = "quantize" , axis = 2 )
253
+ node2 = helper .make_node ("QuantizeLinear" , ["Y" , "scale" , "zero_point" ], ["Z" ], name = "quantize" , axis = 1 )
252
254
node3 = helper .make_node ("Transpose" , ["Z" ], ["Z1" ], perm = perm_output , name = "trans_2" )
253
255
254
256
graph = helper .make_graph (
@@ -264,6 +266,7 @@ def test_transpose_quantize_with_axis(self, shape, perm_input, perm_output):
264
266
model_proto , remaining_transpose_num = 0 )
265
267
266
268
@parameterized .expand ([
269
+ ((2 , 3 , 4 ), [2 , 0 , 1 ], [1 , 2 , 0 ]),
267
270
((2 , 3 , 4 , 5 ), [0 , 2 , 3 , 1 ], [0 , 3 , 1 , 2 ]),
268
271
((2 , 3 , 4 , 5 , 6 ), [0 , 2 , 3 , 4 , 1 ], [0 , 4 , 1 , 2 , 3 ]),
269
272
])
@@ -288,15 +291,16 @@ def test_transpose_dequantize(self, shape, perm_input, perm_output):
288
291
model_proto , remaining_transpose_num = 0 )
289
292
290
293
@parameterized .expand ([
294
+ ((2 , 3 , 4 ), [0 , 2 , 1 ], [0 , 2 , 1 ]),
291
295
((2 , 3 , 4 , 5 ), [0 , 2 , 3 , 1 ], [0 , 3 , 1 , 2 ]),
292
296
((2 , 3 , 4 , 5 , 6 ), [0 , 2 , 3 , 4 , 1 ], [0 , 4 , 1 , 2 , 3 ]),
293
297
])
294
298
@check_opset_min_version (13 , "DequantizeLinear with axis" )
295
299
def test_transpose_dequantize_with_axis (self , shape , perm_input , perm_output ):
296
- scale = numpy_helper .from_array (np .array ([0.75 , 0.1 , 2.3 , 0.3 , 0.42 ], dtype = np .float32 ), name = 'scale' )
297
- zero_point = numpy_helper .from_array (np .array ([2 , 4 , 6 , 8 , 10 ], dtype = np .uint8 ), name = 'zero_point' )
300
+ scale = numpy_helper .from_array (np .array ([0.75 , 0.1 , 2.3 , 0.3 ], dtype = np .float32 ), name = 'scale' )
301
+ zero_point = numpy_helper .from_array (np .array ([2 , 4 , 6 , 8 ], dtype = np .uint8 ), name = 'zero_point' )
298
302
node1 = helper .make_node ("Transpose" , ["X" ], ["Y" ], perm = perm_input , name = "trans_1" )
299
- node2 = helper .make_node ("DequantizeLinear" , ["Y" , "scale" , "zero_point" ], ["Z" ], name = "dequantize" , axis = 2 )
303
+ node2 = helper .make_node ("DequantizeLinear" , ["Y" , "scale" , "zero_point" ], ["Z" ], name = "dequantize" , axis = 1 )
300
304
node3 = helper .make_node ("Transpose" , ["Z" ], ["Z1" ], perm = perm_output , name = "trans_2" )
301
305
302
306
graph = helper .make_graph (
@@ -312,6 +316,7 @@ def test_transpose_dequantize_with_axis(self, shape, perm_input, perm_output):
312
316
model_proto , remaining_transpose_num = 0 )
313
317
314
318
@parameterized .expand ([
319
+ ([2 , 3 , 4 ], [1 , 2 , 1 ], [1 ], [0 , 2 , 1 ], [0 , 2 , 1 ]),
315
320
([2 , 3 , 4 , 5 ], [1 , 2 , 1 , 2 ], [1 ], [0 , 2 , 3 , 1 ], [0 , 3 , 1 , 2 ]),
316
321
([2 , 3 , 4 , 5 ], [1 , 2 , 1 , 2 ], [1 , 2 ], [0 , 2 , 3 , 1 ], [0 , 3 , 1 , 2 ]),
317
322
([2 , 3 , 4 , 5 ], [1 , 2 , 1 , 2 ], [0 , 1 , 2 , 3 ], [0 , 2 , 3 , 1 ], [0 , 3 , 1 , 2 ]),
@@ -351,6 +356,7 @@ def test_transpose_slice(self, input_shape, slice_size, axes, perm_input, perm_o
351
356
model_proto , remaining_transpose_num = 0 )
352
357
353
358
@parameterized .expand ([
359
+ ([2 , 3 , 4 ], [1 , 2 , 1 ], [1 ], [0 , 2 , 1 ], [0 , 2 , 1 ]),
354
360
([2 , 3 , 4 , 5 ], [1 , 2 , 1 , 2 ], [1 ], [0 , 2 , 3 , 1 ], [0 , 3 , 1 , 2 ]),
355
361
([2 , 3 , 4 , 5 ], [1 , 2 , 1 , 2 ], [1 , 2 ], [0 , 2 , 3 , 1 ], [0 , 3 , 1 , 2 ]),
356
362
([2 , 3 , 4 , 5 ], [1 , 2 , 1 , 2 ], [0 , 1 , 2 , 3 ], [0 , 2 , 3 , 1 ], [0 , 3 , 1 , 2 ]),
@@ -390,6 +396,7 @@ def test_transpose_slice_opset_10(self, input_shape, slice_size, axes, perm_inpu
390
396
model_proto , remaining_transpose_num = 0 )
391
397
392
398
@parameterized .expand ([
399
+ ((2 , 3 , 4 ), (4 , 2 , 3 ), (2 , 0 , 1 ), (1 , 2 , 0 )),
393
400
((2 , 3 , 4 , 5 ), (2 , 4 , 5 , 3 ), [0 , 2 , 3 , 1 ], [0 , 3 , 1 , 2 ]),
394
401
((2 , 3 , 4 , 5 , 6 ), (2 , 4 , 5 , 6 , 3 ), [0 , 2 , 3 , 4 , 1 ], [0 , 4 , 1 , 2 , 3 ]),
395
402
])
@@ -513,6 +520,7 @@ def test_transpose_merge(self, input_shape1, input_shape2, perm):
513
520
514
521
515
522
@parameterized .expand ([
523
+ ((2 , 3 , 4 ), [2 , 0 , 1 ], [1 , 2 , 0 ]),
516
524
((2 , 3 , 4 , 5 ), [0 , 2 , 3 , 1 ], [0 , 3 , 1 , 2 ]),
517
525
((2 , 3 , 4 , 5 , 6 ), [0 , 2 , 3 , 4 , 1 ], [0 , 4 , 1 , 2 , 3 ]),
518
526
])
@@ -533,6 +541,7 @@ def test_transpose_mul_as_square(self, shape, perm_input, perm_output):
533
541
model_proto , remaining_transpose_num = 0 )
534
542
535
543
@parameterized .expand ([
544
+ ((2 , 3 , 4 ), [2 , 0 , 1 ], [1 , 2 , 0 ]),
536
545
((2 , 3 , 4 , 5 ), [0 , 2 , 3 , 1 ], [0 , 3 , 1 , 2 ]),
537
546
((2 , 3 , 4 , 5 , 6 ), [0 , 2 , 3 , 4 , 1 ], [0 , 4 , 1 , 2 , 3 ]),
538
547
])
@@ -555,6 +564,7 @@ def test_transpose_mul_broadcastable_const(self, shape, perm_input, perm_output)
555
564
model_proto , remaining_transpose_num = 0 )
556
565
557
566
@parameterized .expand ([
567
+ ((2 , 3 , 4 ), [2 , 0 , 1 ]),
558
568
((2 , 3 , 4 , 5 ), [0 , 2 , 3 , 1 ]),
559
569
((2 , 3 , 4 , 5 , 6 ), [0 , 2 , 3 , 4 , 1 ]),
560
570
])
@@ -574,6 +584,7 @@ def test_transpose_with_shape(self, shape, perm):
574
584
model_proto , remaining_transpose_num = 0 )
575
585
576
586
@parameterized .expand ([
587
+ ((2 , 3 , 4 ), (4 , 2 , 3 ), [2 , 0 , 1 ]),
577
588
((2 , 3 , 4 , 5 ), (2 , 4 , 5 , 3 ), [0 , 2 , 3 , 1 ]),
578
589
((2 , 3 , 4 , 5 , 6 ), (2 , 4 , 5 , 6 , 3 ), [0 , 2 , 3 , 4 , 1 ]),
579
590
])
@@ -593,6 +604,7 @@ def test_transpose_with_identity(self, input_shape, output_shape, perm):
593
604
model_proto , remaining_transpose_num = 1 )
594
605
595
606
@parameterized .expand ([
607
+ ((2 , 3 , 4 ), [2 , 0 , 1 ], [1 , 2 , 0 ]),
596
608
((2 , 3 , 4 , 5 ), [0 , 2 , 3 , 1 ], [0 , 3 , 1 , 2 ]),
597
609
((2 , 3 , 4 , 5 , 6 ), [0 , 2 , 3 , 4 , 1 ], [0 , 4 , 1 , 2 , 3 ]),
598
610
])
@@ -613,6 +625,7 @@ def test_transpose_sqrt(self, shape, perm_input, perm_output):
613
625
model_proto , remaining_transpose_num = 0 )
614
626
615
627
@parameterized .expand ([
628
+ ((1 , 3 , 4 ), [4 , 3 ], [0 , 2 , 1 ], [1 , 0 ]),
616
629
((1 , 3 , 4 , 5 ), (4 , 5 , 3 ), [0 , 2 , 3 , 1 ], [1 , 2 , 0 ]),
617
630
((1 , 3 , 4 , 5 , 6 ), (4 , 5 , 6 , 3 ), [0 , 2 , 3 , 4 , 1 ], [1 , 2 , 3 , 0 ]),
618
631
])
@@ -635,17 +648,18 @@ def test_transpose_with_squeeze1(self, input_shape, output_shape, perm, expected
635
648
self .check_transpose_perm (model_after_opt , expected_perm )
636
649
637
650
@parameterized .expand ([
638
- ((1 , 3 , 4 , 5 ), (1 , 1 , 4 , 5 , 1 , 3 , 1 ), [0 , 2 , 3 , 1 ], [0 , 1 , 4 , 5 , 2 , 3 , 6 ]),
639
- ((1 , 3 , 4 , 5 , 6 ), (1 , 1 , 4 , 5 , 1 , 6 , 1 , 3 ), [0 , 2 , 3 , 4 , 1 ], [0 , 1 , 4 , 5 , 6 , 7 , 2 , 3 ]),
651
+ ((1 , 3 , 4 ), (1 , 4 , 1 , 3 , 1 , 1 ), [2 , 0 , 1 ], [0 , 4 , 5 ], [2 , 3 , 0 , 1 , 4 , 5 ]),
652
+ ((1 , 3 , 4 , 5 ), (1 , 1 , 4 , 5 , 1 , 3 , 1 ), [0 , 2 , 3 , 1 ], [0 , 4 , 6 ], [0 , 1 , 4 , 5 , 2 , 3 , 6 ]),
653
+ ((1 , 3 , 4 , 5 , 6 ), (1 , 1 , 4 , 5 , 1 , 6 , 1 , 3 ), [0 , 2 , 3 , 4 , 1 ], [0 , 4 , 6 ], [0 , 1 , 4 , 5 , 6 , 7 , 2 , 3 ]),
640
654
])
641
- def test_transpose_with_unsqueeze (self , input_shape , output_shape , perm , expected_perm ):
655
+ def test_transpose_with_unsqueeze (self , input_shape , output_shape , perm , axes_val , expected_perm ):
642
656
# unsqueeze the first dim
643
657
node1 = helper .make_node ("Transpose" , ["X" ], ["Y" ], perm = perm , name = "trans" )
644
658
if self .config .opset <= 12 :
645
- node2 = helper .make_node ("Unsqueeze" , ["Y" ], ["Z" ], name = "unsqueeze" , axes = [ 0 , 4 , 6 ] )
659
+ node2 = helper .make_node ("Unsqueeze" , ["Y" ], ["Z" ], name = "unsqueeze" , axes = axes_val )
646
660
nodes = [node1 , node2 ]
647
661
else :
648
- axes = self ._make_onnx_const (np .array ([ 0 , 4 , 6 ] , dtype = np .int64 ), "axes" )
662
+ axes = self ._make_onnx_const (np .array (axes_val , dtype = np .int64 ), "axes" )
649
663
node2 = helper .make_node ("Unsqueeze" , ["Y" , "axes" ], ["Z" ], name = "unsqueeze" )
650
664
nodes = [axes , node1 , node2 ]
651
665
@@ -662,6 +676,7 @@ def test_transpose_with_unsqueeze(self, input_shape, output_shape, perm, expecte
662
676
self .check_transpose_perm (model_after_opt , expected_perm )
663
677
664
678
@parameterized .expand ([
679
+ ((1 , 3 , 4 ), [4 , 3 ], [0 , 2 , 1 ], [1 , 0 ]),
665
680
((1 , 3 , 4 , 5 ), (4 , 5 , 3 ), [0 , 2 , 3 , 1 ], [1 , 2 , 0 ]),
666
681
((1 , 3 , 4 , 5 , 6 ), (4 , 5 , 6 , 3 ), [0 , 2 , 3 , 4 , 1 ], [1 , 2 , 3 , 0 ]),
667
682
])
@@ -816,6 +831,7 @@ def test_transpose_with_squeeze4_13(self, input_shape, output_shape, perm):
816
831
model_proto , remaining_transpose_num = 0 )
817
832
818
833
@parameterized .expand ([
834
+ ((10 , 3 , 4 ), [0 , 2 , 1 ], [0 , 2 , 1 ]),
819
835
((10 , 3 , 4 , 5 ), [0 , 2 , 3 , 1 ], [0 , 3 , 1 , 2 ]),
820
836
((10 , 3 , 4 , 5 , 6 ), [0 , 2 , 3 , 4 , 1 ], [0 , 4 , 1 , 2 , 3 ]),
821
837
])
@@ -880,6 +896,7 @@ def _make_loop(external_inputs, outputs):
880
896
model_proto , remaining_transpose_num = 0 )
881
897
882
898
@parameterized .expand ([
899
+ ((2 , 3 , 4 ), [4 , 2 , 3 ], [2 , 0 , 1 ], [1 , 2 , 0 ]),
883
900
((2 , 3 , 4 , 5 ), [2 , 4 , 5 , 3 ], [0 , 2 , 3 , 1 ], [0 , 3 , 1 , 2 ]),
884
901
((2 , 3 , 4 , 5 , 6 ), [2 , 4 , 5 , 6 , 3 ], [0 , 2 , 3 , 4 , 1 ], [0 , 4 , 1 , 2 , 3 ]),
885
902
])
@@ -965,6 +982,7 @@ def test_transpose_add_with_input_non_const(self, input_shape1, input_shape2, pe
965
982
model_proto , remaining_transpose_num = 0 )
966
983
967
984
@parameterized .expand ([
985
+ ((2 , 3 , 4 ), [4 , 2 , 3 ], [2 , 0 , 1 ], [1 , 2 , 0 ]),
968
986
((1 , 1 , 3 , 3 ), (1 , 3 , 3 , 1 ), [0 , 2 , 3 , 1 ], [0 , 3 , 1 , 2 ]),
969
987
((1 , 1 , 3 , 3 , 3 ), (1 , 3 , 3 , 3 , 1 ), [0 , 2 , 3 , 4 , 1 ], [0 , 4 , 1 , 2 , 3 ]),
970
988
])
@@ -1050,6 +1068,7 @@ def test_transpose_add_with_conv_2(self, input_shape, weights_shape, output_shap
1050
1068
model_proto , remaining_transpose_num = 0 )
1051
1069
1052
1070
@parameterized .expand ([
1071
+ ((3 , 4 , 5 ), (8 , 4 , 6 ), [1 , 3 , 0 , 0 , 2 , 0 ], [2 , 0 , 1 ], [1 , 2 , 0 ]),
1053
1072
((1 , 3 , 4 , 5 ), (2 , 6 , 4 , 8 ), [1 , 0 , 1 , 3 , 0 , 0 , 2 , 0 ], [0 , 2 , 3 , 1 ], [0 , 3 , 1 , 2 ]),
1054
1073
((1 , 3 , 4 , 5 , 6 ), (2 , 5 , 6 , 8 , 10 ), [1 , 0 , 1 , 3 , 1 , 0 , 2 , 2 , 1 , 1 ], [0 , 2 , 3 , 4 , 1 ], [0 , 4 , 1 , 2 , 3 ]),
1055
1074
])
@@ -1071,6 +1090,7 @@ def test_transpose_pad(self, input_shape, output_shape, pads, perm_input, perm_o
1071
1090
model_proto , remaining_transpose_num = 0 )
1072
1091
1073
1092
@parameterized .expand ([
1093
+ ((3 , 4 , 5 ), (8 , 4 , 6 ), [1 , 3 , 0 , 0 , 2 , 0 ], [2 , 0 , 1 ], [1 , 2 , 0 ]),
1074
1094
((1 , 3 , 4 , 5 ), (2 , 6 , 4 , 8 ), [1 , 0 , 1 , 3 , 0 , 0 , 2 , 0 ], [0 , 2 , 3 , 1 ], [0 , 3 , 1 , 2 ]),
1075
1095
((1 , 3 , 4 , 5 , 6 ), (2 , 5 , 6 , 8 , 10 ), [1 , 0 , 1 , 3 , 1 , 0 , 2 , 2 , 1 , 1 ], [0 , 2 , 3 , 4 , 1 ], [0 , 4 , 1 , 2 , 3 ]),
1076
1096
])
@@ -1097,6 +1117,7 @@ def test_transpose_pad11(self, input_shape, output_shape, pads, perm_input, perm
1097
1117
model_proto , remaining_transpose_num = 0 )
1098
1118
1099
1119
@parameterized .expand ([
1120
+ ((3 , 4 , 5 ), (8 , 4 , 6 ), [1 , 3 , 0 , 0 , 2 , 0 ], [2 , 0 , 1 ], [1 , 2 , 0 ]),
1100
1121
((1 , 3 , 4 , 5 ), (2 , 6 , 4 , 8 ), [1 , 0 , 1 , 3 , 0 , 0 , 2 , 0 ], [0 , 2 , 3 , 1 ], [0 , 3 , 1 , 2 ]),
1101
1122
((1 , 3 , 4 , 5 , 6 ), (2 , 5 , 6 , 8 , 10 ), [1 , 0 , 1 , 3 , 1 , 0 , 2 , 2 , 1 , 1 ], [0 , 2 , 3 , 4 , 1 ], [0 , 4 , 1 , 2 , 3 ]),
1102
1123
])
@@ -1125,6 +1146,7 @@ def test_transpose_pad11_non_const_pads(self, input_shape, output_shape, pads, p
1125
1146
model_proto , remaining_transpose_num = 0 )
1126
1147
1127
1148
@parameterized .expand ([
1149
+ ((2 , 3 , 4 ), [2 , 0 , 1 ], [1 , 2 , 0 ]),
1128
1150
((2 , 3 , 4 , 5 ), [0 , 2 , 3 , 1 ], [0 , 3 , 1 , 2 ]),
1129
1151
((2 , 3 , 4 , 5 , 6 ), [0 , 2 , 3 , 4 , 1 ], [0 , 4 , 1 , 2 , 3 ]),
1130
1152
])
@@ -1145,6 +1167,7 @@ def test_transpose_reciprocal(self, shape, perm_input, perm_output):
1145
1167
model_proto , remaining_transpose_num = 0 )
1146
1168
1147
1169
@parameterized .expand ([
1170
+ ((3 , 4 , 5 ), (3 , 4 , 1 ), [0 , 2 , 1 ], [0 , 2 , 1 ]),
1148
1171
((1 , 3 , 4 , 5 ), (1 , 3 , 1 , 1 ), [0 , 2 , 3 , 1 ], [0 , 3 , 1 , 2 ]),
1149
1172
((1 , 3 , 4 , 5 , 6 ), (1 , 3 , 1 , 1 , 1 ), [0 , 2 , 3 , 4 , 1 ], [0 , 4 , 1 , 2 , 3 ]),
1150
1173
])
@@ -1166,6 +1189,7 @@ def test_transpose_reducemean(self, input_shape, output_shape, perm_input, perm_
1166
1189
model_proto , remaining_transpose_num = 0 )
1167
1190
1168
1191
@parameterized .expand ([
1192
+ ((3 , 4 , 5 ), (3 , 4 , 1 ), [1 ], [0 , 2 , 1 ], [0 , 2 , 1 ]),
1169
1193
((1 , 3 , 4 , 5 ), (1 , 3 , 4 , 1 ), [2 ], [0 , 2 , 3 , 1 ], [0 , 3 , 1 , 2 ]),
1170
1194
((1 , 3 , 4 , 5 ), (1 , 3 , 1 , 1 ), [1 , 2 ], [0 , 2 , 3 , 1 ], [0 , 3 , 1 , 2 ]),
1171
1195
((1 , 3 , 4 , 5 ), (1 , 1 , 1 , 1 ), [0 , 1 , 2 , 3 ], [0 , 2 , 3 , 1 ], [0 , 3 , 1 , 2 ]),
@@ -1258,6 +1282,7 @@ def test_transpose_tile(self):
1258
1282
model_proto , remaining_transpose_num = 0 )
1259
1283
1260
1284
@parameterized .expand ([
1285
+ ((3 , 4 , 5 ), (3 , 4 , 1 ), [1 ], [0 , 2 , 1 ], [0 , 2 , 1 ]),
1261
1286
((1 , 3 , 4 , 5 ), (1 , 3 , 4 , 1 ), [2 ], [0 , 2 , 3 , 1 ], [0 , 3 , 1 , 2 ]),
1262
1287
((1 , 3 , 4 , 5 ), (1 , 3 , 1 , 1 ), [1 , 2 ], [0 , 2 , 3 , 1 ], [0 , 3 , 1 , 2 ]),
1263
1288
((1 , 3 , 4 , 5 ), (1 , 1 , 1 , 1 ), [0 , 1 , 2 , 3 ], [0 , 2 , 3 , 1 ], [0 , 3 , 1 , 2 ]),
@@ -1286,6 +1311,7 @@ def test_transpose_reducesum_opset_13(self, input_shape, output_shape, axes, per
1286
1311
model_proto , remaining_transpose_num = 0 )
1287
1312
1288
1313
@parameterized .expand ([
1314
+ ((2 , 3 , 4 ), (4 , 2 , 3 ), [2 , 0 , 1 ]),
1289
1315
((2 , 3 , 4 , 5 ), (2 , 4 , 5 , 3 ), [0 , 2 , 3 , 1 ]),
1290
1316
((2 , 3 , 4 , 5 , 6 ), (2 , 4 , 5 , 6 , 3 ), [0 , 2 , 3 , 4 , 1 ]),
1291
1317
])
@@ -1362,6 +1388,7 @@ def test_trans_can_be_replaced_with_reshape2(self, input_shape_np, input_shape,
1362
1388
model_proto , remaining_transpose_num = 0 )
1363
1389
1364
1390
@parameterized .expand ([
1391
+ ((1 , 6 , 8 ), [2 , 0 , 1 ], [1 , 2 , 0 ]),
1365
1392
((1 , 6 , 8 , 9 ), [0 , 2 , 3 , 1 ], [0 , 3 , 1 , 2 ]),
1366
1393
((1 , 6 , 8 , 9 , 2 ), [0 , 2 , 3 , 4 , 1 ], [0 , 4 , 1 , 2 , 3 ]),
1367
1394
])
@@ -1388,6 +1415,7 @@ def test_two_transposes_switch_with_mul(self, shape, perm_input, perm_output):
1388
1415
model_proto , remaining_transpose_num = 0 )
1389
1416
1390
1417
@parameterized .expand ([
1418
+ ((1 , 6 , 8 ), (8 , 1 , 6 ), [2 , 0 , 1 ], [1 , 2 , 0 ]),
1391
1419
((1 , 6 , 8 , 9 ), (1 , 8 , 9 , 6 ), [0 , 2 , 3 , 1 ], [0 , 3 , 1 , 2 ]),
1392
1420
((1 , 6 , 8 , 9 , 2 ), (1 , 8 , 9 , 2 , 6 ), [0 , 2 , 3 , 4 , 1 ], [0 , 4 , 1 , 2 , 3 ]),
1393
1421
])
0 commit comments