Skip to content

Commit ac4c68d

Browse files
Tom/enable transpose any rank (#1453)
* Check shapes and dtypes in optimizer tests Signed-off-by: Tom Wildenhain <[email protected]> * Enable transpose optimization of any rank Signed-off-by: Tom Wildenhain <[email protected]> * Fix shape bug in transpose opt Signed-off-by: Tom Wildenhain <[email protected]>
1 parent 059afaf commit ac4c68d

File tree

3 files changed

+79
-34
lines changed

3 files changed

+79
-34
lines changed

tests/run_pretrained_models.yaml

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -552,7 +552,7 @@ mobilebert_tflite:
552552

553553
palm_detection_tflite:
554554
tf_min_version: 2.1
555-
disabled: false # Converts but produces incorrect results. Working on fixing it.
555+
disabled: false
556556
url: https://github.com/google/mediapipe/raw/master/mediapipe/modules/palm_detection/palm_detection.tflite
557557
model: "palm_detection.tflite"
558558
model_type: tflite
@@ -568,3 +568,21 @@ palm_detection_tflite:
568568
- classificators
569569
rtol: 0.02
570570
atol: 0.0005
571+
572+
melgan_tflite: # TFLite model with FlexOps and rank-3 transposes
573+
tf_min_version: 2.1
574+
disabled: false
575+
url: https://tfhub.dev/tulasiram58827/lite-model/melgan/dr/1?lite-format=tflite
576+
model: "melgan.tflite"
577+
model_type: tflite
578+
input_get: get_zeros
579+
opset_constraints:
580+
"onnx":
581+
"min": 11
582+
dequantize: true
583+
inputs:
584+
"mels": [1, 100, 80]
585+
outputs:
586+
- Identity
587+
rtol: 0.02
588+
atol: 0.0005

tests/test_optimizers.py

Lines changed: 48 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,10 @@ def run_and_compare(self, output_names_with_port, onnx_feed_dict, origin_proto,
2727
remaining_op_num, debug=False, rtol=1e-07):
2828
utils.make_sure(op_type is not None, "op_type should be specified")
2929
utils.make_sure(remaining_op_num is not None, "remaining_op_num should be specified")
30+
utils.make_sure(self.config.is_onnxruntime_backend, "only onnxruntime is supported to test transpose optimizer")
3031

3132
origin_model_path = self.save_onnx_model(origin_proto, onnx_feed_dict, postfix="_origin")
33+
expected = self.run_onnxruntime(origin_model_path, onnx_feed_dict, output_names_with_port)
3234

3335
new_proto, new_graph = GraphUtil.optimize_model_proto(origin_proto, catch_errors=False, return_graph=True)
3436

@@ -37,21 +39,16 @@ def run_and_compare(self, output_names_with_port, onnx_feed_dict, origin_proto,
3739
new_model_path = self.save_onnx_model(new_proto, onnx_feed_dict, postfix="_opt")
3840
current = GraphUtil.get_node_count_from_onnx_graph(new_proto.graph)
3941

40-
self.assertTrue(current[op_type] == remaining_op_num,
41-
msg="Expect " + str(remaining_op_num) + " " + op_type + " ops left, but actually " + str(
42-
current[op_type]) + " left")
43-
44-
if self.config.is_onnxruntime_backend:
45-
expected = self.run_onnxruntime(origin_model_path, onnx_feed_dict, output_names_with_port)
46-
actual = self.run_onnxruntime(new_model_path, onnx_feed_dict, output_names_with_port)
47-
else:
48-
raise ValueError("only onnxruntime is supported to test transpose optimizer")
42+
actual = self.run_onnxruntime(new_model_path, onnx_feed_dict, output_names_with_port)
4943

5044
for expected_val, actual_val in zip(expected, actual):
5145
self.assertAllClose(expected_val, actual_val, rtol=rtol, atol=1e-5)
5246
self.assertEqual(expected_val.dtype, actual_val.dtype)
5347
self.assertEqual(expected_val.shape, actual_val.shape)
5448

49+
self.assertTrue(current[op_type] == remaining_op_num,
50+
msg="Expect " + str(remaining_op_num) + " " + op_type + " ops left, but actually " + str(
51+
current[op_type]) + " left")
5552
self.assert_shapes_correct(new_graph, allow_missing=False, run_checker=True)
5653

5754
return new_proto
@@ -124,6 +121,7 @@ def test_transpose_with_concat(self, input_shape, perm, inner_perm):
124121
self.run_transpose_compare(["res"], feed_dict, model_proto, remaining_transpose_num=1)
125122

126123
@parameterized.expand([
124+
((2, 3, 4), [2, 0, 1], [1, 2, 0]),
127125
((2, 3, 4, 5), [0, 2, 3, 1], [0, 3, 1, 2]),
128126
((2, 3, 4, 5, 6), [0, 2, 3, 4, 1], [0, 4, 1, 2, 3]),
129127
])
@@ -176,6 +174,7 @@ def test_transpose_with_add2(self, input_shape1, input_shape2, perm_input, perm_
176174
self.run_transpose_compare(["res"], feed_dict, model_proto, remaining_transpose_num=1)
177175

178176
@parameterized.expand([
177+
((2, 3, 4), [2, 0, 1], [1, 2, 0]),
179178
((2, 3, 4, 5), [0, 2, 3, 1], [0, 3, 1, 2]),
180179
((2, 3, 4, 5, 6), [0, 2, 3, 4, 1], [0, 4, 1, 2, 3]),
181180
])
@@ -196,6 +195,7 @@ def test_transpose_relu(self, shape, perm_input, perm_output):
196195
model_proto, remaining_transpose_num=0)
197196

198197
@parameterized.expand([
198+
((2, 3, 4), [2, 0, 1], [1, 2, 0]),
199199
((2, 3, 4, 5), [0, 2, 3, 1], [0, 3, 1, 2]),
200200
((2, 3, 4, 5, 6), [0, 2, 3, 4, 1], [0, 4, 1, 2, 3]),
201201
])
@@ -216,6 +216,7 @@ def test_transpose_leaky_relu(self, shape, perm_input, perm_output):
216216
model_proto, remaining_transpose_num=0)
217217

218218
@parameterized.expand([
219+
((2, 3, 4), [2, 0, 1], [1, 2, 0]),
219220
((2, 3, 4, 5), [0, 2, 3, 1], [0, 3, 1, 2]),
220221
((2, 3, 4, 5, 6), [0, 2, 3, 4, 1], [0, 4, 1, 2, 3]),
221222
])
@@ -240,15 +241,16 @@ def test_transpose_quantize(self, shape, perm_input, perm_output):
240241
model_proto, remaining_transpose_num=0)
241242

242243
@parameterized.expand([
244+
((2, 3, 4), [0, 2, 1], [0, 2, 1]),
243245
((2, 3, 4, 5), [0, 2, 3, 1], [0, 3, 1, 2]),
244246
((2, 3, 4, 5, 6), [0, 2, 3, 4, 1], [0, 4, 1, 2, 3]),
245247
])
246248
@check_opset_min_version(13, "QuantizeLinear with axis")
247249
def test_transpose_quantize_with_axis(self, shape, perm_input, perm_output):
248-
scale = numpy_helper.from_array(np.array([0.75, 0.1, 2.3, 0.3, 0.42], dtype=np.float32), name='scale')
249-
zero_point = numpy_helper.from_array(np.array([2, 4, 6, 8, 10], dtype=np.uint8), name='zero_point')
250+
scale = numpy_helper.from_array(np.array([0.75, 0.1, 2.3, 0.3], dtype=np.float32), name='scale')
251+
zero_point = numpy_helper.from_array(np.array([2, 4, 6, 8], dtype=np.uint8), name='zero_point')
250252
node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=perm_input, name="trans_1")
251-
node2 = helper.make_node("QuantizeLinear", ["Y", "scale", "zero_point"], ["Z"], name="quantize", axis=2)
253+
node2 = helper.make_node("QuantizeLinear", ["Y", "scale", "zero_point"], ["Z"], name="quantize", axis=1)
252254
node3 = helper.make_node("Transpose", ["Z"], ["Z1"], perm=perm_output, name="trans_2")
253255

254256
graph = helper.make_graph(
@@ -264,6 +266,7 @@ def test_transpose_quantize_with_axis(self, shape, perm_input, perm_output):
264266
model_proto, remaining_transpose_num=0)
265267

266268
@parameterized.expand([
269+
((2, 3, 4), [2, 0, 1], [1, 2, 0]),
267270
((2, 3, 4, 5), [0, 2, 3, 1], [0, 3, 1, 2]),
268271
((2, 3, 4, 5, 6), [0, 2, 3, 4, 1], [0, 4, 1, 2, 3]),
269272
])
@@ -288,15 +291,16 @@ def test_transpose_dequantize(self, shape, perm_input, perm_output):
288291
model_proto, remaining_transpose_num=0)
289292

290293
@parameterized.expand([
294+
((2, 3, 4), [0, 2, 1], [0, 2, 1]),
291295
((2, 3, 4, 5), [0, 2, 3, 1], [0, 3, 1, 2]),
292296
((2, 3, 4, 5, 6), [0, 2, 3, 4, 1], [0, 4, 1, 2, 3]),
293297
])
294298
@check_opset_min_version(13, "DequantizeLinear with axis")
295299
def test_transpose_dequantize_with_axis(self, shape, perm_input, perm_output):
296-
scale = numpy_helper.from_array(np.array([0.75, 0.1, 2.3, 0.3, 0.42], dtype=np.float32), name='scale')
297-
zero_point = numpy_helper.from_array(np.array([2, 4, 6, 8, 10], dtype=np.uint8), name='zero_point')
300+
scale = numpy_helper.from_array(np.array([0.75, 0.1, 2.3, 0.3], dtype=np.float32), name='scale')
301+
zero_point = numpy_helper.from_array(np.array([2, 4, 6, 8], dtype=np.uint8), name='zero_point')
298302
node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=perm_input, name="trans_1")
299-
node2 = helper.make_node("DequantizeLinear", ["Y", "scale", "zero_point"], ["Z"], name="dequantize", axis=2)
303+
node2 = helper.make_node("DequantizeLinear", ["Y", "scale", "zero_point"], ["Z"], name="dequantize", axis=1)
300304
node3 = helper.make_node("Transpose", ["Z"], ["Z1"], perm=perm_output, name="trans_2")
301305

302306
graph = helper.make_graph(
@@ -312,6 +316,7 @@ def test_transpose_dequantize_with_axis(self, shape, perm_input, perm_output):
312316
model_proto, remaining_transpose_num=0)
313317

314318
@parameterized.expand([
319+
([2, 3, 4], [1, 2, 1], [1], [0, 2, 1], [0, 2, 1]),
315320
([2, 3, 4, 5], [1, 2, 1, 2], [1], [0, 2, 3, 1], [0, 3, 1, 2]),
316321
([2, 3, 4, 5], [1, 2, 1, 2], [1, 2], [0, 2, 3, 1], [0, 3, 1, 2]),
317322
([2, 3, 4, 5], [1, 2, 1, 2], [0, 1, 2, 3], [0, 2, 3, 1], [0, 3, 1, 2]),
@@ -351,6 +356,7 @@ def test_transpose_slice(self, input_shape, slice_size, axes, perm_input, perm_o
351356
model_proto, remaining_transpose_num=0)
352357

353358
@parameterized.expand([
359+
([2, 3, 4], [1, 2, 1], [1], [0, 2, 1], [0, 2, 1]),
354360
([2, 3, 4, 5], [1, 2, 1, 2], [1], [0, 2, 3, 1], [0, 3, 1, 2]),
355361
([2, 3, 4, 5], [1, 2, 1, 2], [1, 2], [0, 2, 3, 1], [0, 3, 1, 2]),
356362
([2, 3, 4, 5], [1, 2, 1, 2], [0, 1, 2, 3], [0, 2, 3, 1], [0, 3, 1, 2]),
@@ -390,6 +396,7 @@ def test_transpose_slice_opset_10(self, input_shape, slice_size, axes, perm_inpu
390396
model_proto, remaining_transpose_num=0)
391397

392398
@parameterized.expand([
399+
((2, 3, 4), (4, 2, 3), (2, 0, 1), (1, 2, 0)),
393400
((2, 3, 4, 5), (2, 4, 5, 3), [0, 2, 3, 1], [0, 3, 1, 2]),
394401
((2, 3, 4, 5, 6), (2, 4, 5, 6, 3), [0, 2, 3, 4, 1], [0, 4, 1, 2, 3]),
395402
])
@@ -513,6 +520,7 @@ def test_transpose_merge(self, input_shape1, input_shape2, perm):
513520

514521

515522
@parameterized.expand([
523+
((2, 3, 4), [2, 0, 1], [1, 2, 0]),
516524
((2, 3, 4, 5), [0, 2, 3, 1], [0, 3, 1, 2]),
517525
((2, 3, 4, 5, 6), [0, 2, 3, 4, 1], [0, 4, 1, 2, 3]),
518526
])
@@ -533,6 +541,7 @@ def test_transpose_mul_as_square(self, shape, perm_input, perm_output):
533541
model_proto, remaining_transpose_num=0)
534542

535543
@parameterized.expand([
544+
((2, 3, 4), [2, 0, 1], [1, 2, 0]),
536545
((2, 3, 4, 5), [0, 2, 3, 1], [0, 3, 1, 2]),
537546
((2, 3, 4, 5, 6), [0, 2, 3, 4, 1], [0, 4, 1, 2, 3]),
538547
])
@@ -555,6 +564,7 @@ def test_transpose_mul_broadcastable_const(self, shape, perm_input, perm_output)
555564
model_proto, remaining_transpose_num=0)
556565

557566
@parameterized.expand([
567+
((2, 3, 4), [2, 0, 1]),
558568
((2, 3, 4, 5), [0, 2, 3, 1]),
559569
((2, 3, 4, 5, 6), [0, 2, 3, 4, 1]),
560570
])
@@ -574,6 +584,7 @@ def test_transpose_with_shape(self, shape, perm):
574584
model_proto, remaining_transpose_num=0)
575585

576586
@parameterized.expand([
587+
((2, 3, 4), (4, 2, 3), [2, 0, 1]),
577588
((2, 3, 4, 5), (2, 4, 5, 3), [0, 2, 3, 1]),
578589
((2, 3, 4, 5, 6), (2, 4, 5, 6, 3), [0, 2, 3, 4, 1]),
579590
])
@@ -593,6 +604,7 @@ def test_transpose_with_identity(self, input_shape, output_shape, perm):
593604
model_proto, remaining_transpose_num=1)
594605

595606
@parameterized.expand([
607+
((2, 3, 4), [2, 0, 1], [1, 2, 0]),
596608
((2, 3, 4, 5), [0, 2, 3, 1], [0, 3, 1, 2]),
597609
((2, 3, 4, 5, 6), [0, 2, 3, 4, 1], [0, 4, 1, 2, 3]),
598610
])
@@ -613,6 +625,7 @@ def test_transpose_sqrt(self, shape, perm_input, perm_output):
613625
model_proto, remaining_transpose_num=0)
614626

615627
@parameterized.expand([
628+
((1, 3, 4), [4, 3], [0, 2, 1], [1, 0]),
616629
((1, 3, 4, 5), (4, 5, 3), [0, 2, 3, 1], [1, 2, 0]),
617630
((1, 3, 4, 5, 6), (4, 5, 6, 3), [0, 2, 3, 4, 1], [1, 2, 3, 0]),
618631
])
@@ -635,17 +648,18 @@ def test_transpose_with_squeeze1(self, input_shape, output_shape, perm, expected
635648
self.check_transpose_perm(model_after_opt, expected_perm)
636649

637650
@parameterized.expand([
638-
((1, 3, 4, 5), (1, 1, 4, 5, 1, 3, 1), [0, 2, 3, 1], [0, 1, 4, 5, 2, 3, 6]),
639-
((1, 3, 4, 5, 6), (1, 1, 4, 5, 1, 6, 1, 3), [0, 2, 3, 4, 1], [0, 1, 4, 5, 6, 7, 2, 3]),
651+
((1, 3, 4), (1, 4, 1, 3, 1, 1), [2, 0, 1], [0, 4, 5], [2, 3, 0, 1, 4, 5]),
652+
((1, 3, 4, 5), (1, 1, 4, 5, 1, 3, 1), [0, 2, 3, 1], [0, 4, 6], [0, 1, 4, 5, 2, 3, 6]),
653+
((1, 3, 4, 5, 6), (1, 1, 4, 5, 1, 6, 1, 3), [0, 2, 3, 4, 1], [0, 4, 6], [0, 1, 4, 5, 6, 7, 2, 3]),
640654
])
641-
def test_transpose_with_unsqueeze(self, input_shape, output_shape, perm, expected_perm):
655+
def test_transpose_with_unsqueeze(self, input_shape, output_shape, perm, axes_val, expected_perm):
642656
# unsqueeze the first dim
643657
node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=perm, name="trans")
644658
if self.config.opset <= 12:
645-
node2 = helper.make_node("Unsqueeze", ["Y"], ["Z"], name="unsqueeze", axes=[0, 4, 6])
659+
node2 = helper.make_node("Unsqueeze", ["Y"], ["Z"], name="unsqueeze", axes=axes_val)
646660
nodes = [node1, node2]
647661
else:
648-
axes = self._make_onnx_const(np.array([0, 4, 6], dtype=np.int64), "axes")
662+
axes = self._make_onnx_const(np.array(axes_val, dtype=np.int64), "axes")
649663
node2 = helper.make_node("Unsqueeze", ["Y", "axes"], ["Z"], name="unsqueeze")
650664
nodes = [axes, node1, node2]
651665

@@ -662,6 +676,7 @@ def test_transpose_with_unsqueeze(self, input_shape, output_shape, perm, expecte
662676
self.check_transpose_perm(model_after_opt, expected_perm)
663677

664678
@parameterized.expand([
679+
((1, 3, 4), [4, 3], [0, 2, 1], [1, 0]),
665680
((1, 3, 4, 5), (4, 5, 3), [0, 2, 3, 1], [1, 2, 0]),
666681
((1, 3, 4, 5, 6), (4, 5, 6, 3), [0, 2, 3, 4, 1], [1, 2, 3, 0]),
667682
])
@@ -816,6 +831,7 @@ def test_transpose_with_squeeze4_13(self, input_shape, output_shape, perm):
816831
model_proto, remaining_transpose_num=0)
817832

818833
@parameterized.expand([
834+
((10, 3, 4), [0, 2, 1], [0, 2, 1]),
819835
((10, 3, 4, 5), [0, 2, 3, 1], [0, 3, 1, 2]),
820836
((10, 3, 4, 5, 6), [0, 2, 3, 4, 1], [0, 4, 1, 2, 3]),
821837
])
@@ -880,6 +896,7 @@ def _make_loop(external_inputs, outputs):
880896
model_proto, remaining_transpose_num=0)
881897

882898
@parameterized.expand([
899+
((2, 3, 4), [4, 2, 3], [2, 0, 1], [1, 2, 0]),
883900
((2, 3, 4, 5), [2, 4, 5, 3], [0, 2, 3, 1], [0, 3, 1, 2]),
884901
((2, 3, 4, 5, 6), [2, 4, 5, 6, 3], [0, 2, 3, 4, 1], [0, 4, 1, 2, 3]),
885902
])
@@ -965,6 +982,7 @@ def test_transpose_add_with_input_non_const(self, input_shape1, input_shape2, pe
965982
model_proto, remaining_transpose_num=0)
966983

967984
@parameterized.expand([
985+
((2, 3, 4), [4, 2, 3], [2, 0, 1], [1, 2, 0]),
968986
((1, 1, 3, 3), (1, 3, 3, 1), [0, 2, 3, 1], [0, 3, 1, 2]),
969987
((1, 1, 3, 3, 3), (1, 3, 3, 3, 1), [0, 2, 3, 4, 1], [0, 4, 1, 2, 3]),
970988
])
@@ -1050,6 +1068,7 @@ def test_transpose_add_with_conv_2(self, input_shape, weights_shape, output_shap
10501068
model_proto, remaining_transpose_num=0)
10511069

10521070
@parameterized.expand([
1071+
((3, 4, 5), (8, 4, 6), [1, 3, 0, 0, 2, 0], [2, 0, 1], [1, 2, 0]),
10531072
((1, 3, 4, 5), (2, 6, 4, 8), [1, 0, 1, 3, 0, 0, 2, 0], [0, 2, 3, 1], [0, 3, 1, 2]),
10541073
((1, 3, 4, 5, 6), (2, 5, 6, 8, 10), [1, 0, 1, 3, 1, 0, 2, 2, 1, 1], [0, 2, 3, 4, 1], [0, 4, 1, 2, 3]),
10551074
])
@@ -1071,6 +1090,7 @@ def test_transpose_pad(self, input_shape, output_shape, pads, perm_input, perm_o
10711090
model_proto, remaining_transpose_num=0)
10721091

10731092
@parameterized.expand([
1093+
((3, 4, 5), (8, 4, 6), [1, 3, 0, 0, 2, 0], [2, 0, 1], [1, 2, 0]),
10741094
((1, 3, 4, 5), (2, 6, 4, 8), [1, 0, 1, 3, 0, 0, 2, 0], [0, 2, 3, 1], [0, 3, 1, 2]),
10751095
((1, 3, 4, 5, 6), (2, 5, 6, 8, 10), [1, 0, 1, 3, 1, 0, 2, 2, 1, 1], [0, 2, 3, 4, 1], [0, 4, 1, 2, 3]),
10761096
])
@@ -1097,6 +1117,7 @@ def test_transpose_pad11(self, input_shape, output_shape, pads, perm_input, perm
10971117
model_proto, remaining_transpose_num=0)
10981118

10991119
@parameterized.expand([
1120+
((3, 4, 5), (8, 4, 6), [1, 3, 0, 0, 2, 0], [2, 0, 1], [1, 2, 0]),
11001121
((1, 3, 4, 5), (2, 6, 4, 8), [1, 0, 1, 3, 0, 0, 2, 0], [0, 2, 3, 1], [0, 3, 1, 2]),
11011122
((1, 3, 4, 5, 6), (2, 5, 6, 8, 10), [1, 0, 1, 3, 1, 0, 2, 2, 1, 1], [0, 2, 3, 4, 1], [0, 4, 1, 2, 3]),
11021123
])
@@ -1125,6 +1146,7 @@ def test_transpose_pad11_non_const_pads(self, input_shape, output_shape, pads, p
11251146
model_proto, remaining_transpose_num=0)
11261147

11271148
@parameterized.expand([
1149+
((2, 3, 4), [2, 0, 1], [1, 2, 0]),
11281150
((2, 3, 4, 5), [0, 2, 3, 1], [0, 3, 1, 2]),
11291151
((2, 3, 4, 5, 6), [0, 2, 3, 4, 1], [0, 4, 1, 2, 3]),
11301152
])
@@ -1145,6 +1167,7 @@ def test_transpose_reciprocal(self, shape, perm_input, perm_output):
11451167
model_proto, remaining_transpose_num=0)
11461168

11471169
@parameterized.expand([
1170+
((3, 4, 5), (3, 4, 1), [0, 2, 1], [0, 2, 1]),
11481171
((1, 3, 4, 5), (1, 3, 1, 1), [0, 2, 3, 1], [0, 3, 1, 2]),
11491172
((1, 3, 4, 5, 6), (1, 3, 1, 1, 1), [0, 2, 3, 4, 1], [0, 4, 1, 2, 3]),
11501173
])
@@ -1166,6 +1189,7 @@ def test_transpose_reducemean(self, input_shape, output_shape, perm_input, perm_
11661189
model_proto, remaining_transpose_num=0)
11671190

11681191
@parameterized.expand([
1192+
((3, 4, 5), (3, 4, 1), [1], [0, 2, 1], [0, 2, 1]),
11691193
((1, 3, 4, 5), (1, 3, 4, 1), [2], [0, 2, 3, 1], [0, 3, 1, 2]),
11701194
((1, 3, 4, 5), (1, 3, 1, 1), [1, 2], [0, 2, 3, 1], [0, 3, 1, 2]),
11711195
((1, 3, 4, 5), (1, 1, 1, 1), [0, 1, 2, 3], [0, 2, 3, 1], [0, 3, 1, 2]),
@@ -1258,6 +1282,7 @@ def test_transpose_tile(self):
12581282
model_proto, remaining_transpose_num=0)
12591283

12601284
@parameterized.expand([
1285+
((3, 4, 5), (3, 4, 1), [1], [0, 2, 1], [0, 2, 1]),
12611286
((1, 3, 4, 5), (1, 3, 4, 1), [2], [0, 2, 3, 1], [0, 3, 1, 2]),
12621287
((1, 3, 4, 5), (1, 3, 1, 1), [1, 2], [0, 2, 3, 1], [0, 3, 1, 2]),
12631288
((1, 3, 4, 5), (1, 1, 1, 1), [0, 1, 2, 3], [0, 2, 3, 1], [0, 3, 1, 2]),
@@ -1286,6 +1311,7 @@ def test_transpose_reducesum_opset_13(self, input_shape, output_shape, axes, per
12861311
model_proto, remaining_transpose_num=0)
12871312

12881313
@parameterized.expand([
1314+
((2, 3, 4), (4, 2, 3), [2, 0, 1]),
12891315
((2, 3, 4, 5), (2, 4, 5, 3), [0, 2, 3, 1]),
12901316
((2, 3, 4, 5, 6), (2, 4, 5, 6, 3), [0, 2, 3, 4, 1]),
12911317
])
@@ -1362,6 +1388,7 @@ def test_trans_can_be_replaced_with_reshape2(self, input_shape_np, input_shape,
13621388
model_proto, remaining_transpose_num=0)
13631389

13641390
@parameterized.expand([
1391+
((1, 6, 8), [2, 0, 1], [1, 2, 0]),
13651392
((1, 6, 8, 9), [0, 2, 3, 1], [0, 3, 1, 2]),
13661393
((1, 6, 8, 9, 2), [0, 2, 3, 4, 1], [0, 4, 1, 2, 3]),
13671394
])
@@ -1388,6 +1415,7 @@ def test_two_transposes_switch_with_mul(self, shape, perm_input, perm_output):
13881415
model_proto, remaining_transpose_num=0)
13891416

13901417
@parameterized.expand([
1418+
((1, 6, 8), (8, 1, 6), [2, 0, 1], [1, 2, 0]),
13911419
((1, 6, 8, 9), (1, 8, 9, 6), [0, 2, 3, 1], [0, 3, 1, 2]),
13921420
((1, 6, 8, 9, 2), (1, 8, 9, 2, 6), [0, 2, 3, 4, 1], [0, 4, 1, 2, 3]),
13931421
])

0 commit comments

Comments
 (0)