Skip to content

Commit 7fbab3e

Browse files
authored
Merge pull request #633 from lei-Qiao/transpose_optimizer
Transpose optimizer
2 parents c6d531c + 71f4d68 commit 7fbab3e

File tree

2 files changed

+185
-31
lines changed

2 files changed

+185
-31
lines changed

tests/test_optimizers.py

Lines changed: 175 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -184,12 +184,12 @@ def test_transpose_max(self):
184184
const_1 = helper.make_tensor("const_1", TensorProto.FLOAT, (1,), const_1_val)
185185
const_1_node = helper.make_node("Constant", [], ["const_1"], value=const_1, name="const_1")
186186

187-
const_2_val = np.random.randn(2, 4, 5, 3).astype(np.float32).reshape(120).tolist()
188-
const_2 = helper.make_tensor("const_2", TensorProto.FLOAT, (2, 4, 5, 3), const_2_val)
187+
const_2_val = np.random.randn(2, 4, 5, 3).astype(np.float32)
188+
const_2 = helper.make_tensor("const_2", TensorProto.FLOAT, (2, 4, 5, 3), const_2_val.flatten())
189189
const_2_node = helper.make_node("Constant", [], ["const_2"], value=const_2, name="const_2")
190190

191-
const_3_val = np.random.randn(2, 4, 5, 3).astype(np.float32).reshape(120).tolist()
192-
const_3 = helper.make_tensor("const_3", TensorProto.FLOAT, (2, 4, 5, 3), const_3_val)
191+
const_3_val = np.random.randn(2, 4, 5, 3).astype(np.float32)
192+
const_3 = helper.make_tensor("const_3", TensorProto.FLOAT, (2, 4, 5, 3), const_3_val.flatten())
193193
const_3_node = helper.make_node("Constant", [], ["const_3"], value=const_3, name="const_3")
194194

195195
node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 2, 3, 1], name="trans_1")
@@ -207,6 +207,32 @@ def test_transpose_max(self):
207207
self.run_transpose_compare(["Z1"], {"X": np.random.randn(2, 3, 4, 5).astype(np.float32)},
208208
model_proto, remaining_transpose_num=0)
209209

210+
def test_transpose_max_input_non_const(self):
    """Max fed by a Transpose plus a non-const graph input: the optimizer
    cannot fold everything, so exactly one Transpose must remain."""
    scalar_val = [2.0]
    const_1 = helper.make_tensor("const_1", TensorProto.FLOAT, (1,), scalar_val)
    const_1_node = helper.make_node("Constant", [], ["const_1"], value=const_1, name="const_1")

    tensor_val = np.random.randn(2, 4, 5, 3).astype(np.float32)
    const_2 = helper.make_tensor("const_2", TensorProto.FLOAT, (2, 4, 5, 3), tensor_val.flatten())
    const_2_node = helper.make_node("Constant", [], ["const_2"], value=const_2, name="const_2")

    trans_in = helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 2, 3, 1], name="trans_1")
    max_node = helper.make_node("Max", ["Y", "non_const", "const_2", "const_1"], ["Z"], name="max")
    trans_out = helper.make_node("Transpose", ["Z"], ["Z1"], perm=[0, 3, 1, 2], name="trans_2")

    graph = helper.make_graph(
        [const_1_node, const_2_node, trans_in, max_node, trans_out],
        "Max-test",
        [helper.make_tensor_value_info("X", TensorProto.FLOAT, (2, 3, 4, 5)),
         helper.make_tensor_value_info("non_const", TensorProto.FLOAT, (2, 4, 5, 3))],
        [helper.make_tensor_value_info("Z1", TensorProto.FLOAT, (2, 3, 4, 5))],
    )

    model_proto = helper.make_model(graph, producer_name="onnx-tests")
    feed = {"X": np.random.randn(2, 3, 4, 5).astype(np.float32),
            "non_const": np.random.randn(2, 4, 5, 3).astype(np.float32)}
    self.run_transpose_compare(["Z1"], feed, model_proto, remaining_transpose_num=1)
235+
210236
def test_transpose_merge(self):
211237
node0 = helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 2, 3, 1], name="trans")
212238
node1 = helper.make_node("Transpose", ["X"], ["Y_1"], perm=[0, 2, 3, 1], name="trans_1")
@@ -394,6 +420,151 @@ def test_trans_with_sub(self):
394420
self.run_transpose_compare(["res"], {"X": np.random.randn(2, 3, 4, 5).astype(np.float32)},
395421
model_proto, remaining_transpose_num=0)
396422

423+
def test_trans_with_sub_input_non_const(self):
    """Sub with one non-const operand (exercising several broadcastable
    shapes and both operand orders): one Transpose must remain."""
    io_shape = [2, 3, 4, 5]
    non_const_shapes = [[2, 4, 5, 3], [4, 5, 3], [5, 3]]
    for trans_is_first_input in [True, False]:
        for non_const_shape in non_const_shapes:
            trans_a = helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 2, 3, 1], name="trans_a")
            # operand order matters for Sub, so test both positions
            sub_inputs = ["Y", "non_const"] if trans_is_first_input else ["non_const", "Y"]
            sub_node = helper.make_node("Sub", sub_inputs, ["Z"], name="sub")

            trans_b = helper.make_node("Transpose", ["Z"], ["res"], perm=[0, 3, 1, 2], name="trans_b")
            graph = helper.make_graph(
                [trans_a, sub_node, trans_b],
                "test_trans_with_sub_input_non_const",
                [helper.make_tensor_value_info("X", TensorProto.FLOAT, io_shape),
                 helper.make_tensor_value_info("non_const", TensorProto.FLOAT, non_const_shape)],
                [helper.make_tensor_value_info("res", TensorProto.FLOAT, io_shape)],
            )

            model_proto = helper.make_model(graph, producer_name="onnx-tests")
            feed = {"X": np.random.randn(2, 3, 4, 5).astype(np.float32),
                    "non_const": np.random.randn(*non_const_shape).astype(np.float32)}
            self.run_transpose_compare(["res"], feed, model_proto, remaining_transpose_num=1)
447+
448+
def test_transpose_add_with_input_non_const(self):
    """Add whose second operand is a non-const graph input: both
    surrounding Transposes are expected to be optimized away."""
    trans_in = helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 2, 3, 1], name="trans_1")
    add_node = helper.make_node("Add", ["Y", "A"], ["Z"], name="add")
    trans_out = helper.make_node("Transpose", ["Z"], ["res"], perm=[0, 3, 1, 2], name="trans_2")

    graph = helper.make_graph(
        [trans_in, add_node, trans_out],
        "transpose-add-test-input-non-const",
        [helper.make_tensor_value_info("X", TensorProto.FLOAT, (1, 1, 3, 3)),
         helper.make_tensor_value_info("A", TensorProto.FLOAT, (1, 3, 3, 1))],
        [helper.make_tensor_value_info("res", TensorProto.FLOAT, (1, 1, 3, 3))],
    )

    model_proto = helper.make_model(graph, producer_name="onnx-tests")
    feed = {"X": np.random.randn(1, 1, 3, 3).astype(np.float32),
            "A": np.random.randn(1, 3, 3, 1).astype(np.float32)}
    self.run_transpose_compare(["res"], feed, model_proto, remaining_transpose_num=0)
466+
467+
def test_transpose_add_with_input_const(self):
    """Add whose second operand is a constant tensor: both surrounding
    Transposes are expected to be optimized away."""
    bias_val = np.random.randn(1, 3, 3, 1).astype(np.float32)
    const_1 = helper.make_tensor("const_1", TensorProto.FLOAT, (1, 3, 3, 1), bias_val.flatten())
    const_1_node = helper.make_node("Constant", [], ["const_1"], value=const_1, name="const_1")

    trans_in = helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 2, 3, 1], name="trans_1")
    add_node = helper.make_node("Add", ["Y", "const_1"], ["Z"], name="add")
    trans_out = helper.make_node("Transpose", ["Z"], ["res"], perm=[0, 3, 1, 2], name="trans_2")

    graph = helper.make_graph(
        [const_1_node, trans_in, add_node, trans_out],
        "transpose-add-test-input-const",
        [helper.make_tensor_value_info("X", TensorProto.FLOAT, (1, 1, 3, 3))],
        [helper.make_tensor_value_info("res", TensorProto.FLOAT, (1, 1, 3, 3))],
    )

    model_proto = helper.make_model(graph, producer_name="onnx-tests")
    feed = {"X": np.random.randn(1, 1, 3, 3).astype(np.float32)}
    self.run_transpose_compare(["res"], feed, model_proto, remaining_transpose_num=0)
486+
487+
def test_transpose_add_with_conv_1(self):
    """Bias constant that is effectively 1-D (only the last axis > 1):
    the Add can be merged into the preceding Conv, removing all Transposes."""
    bias_val = np.random.randn(1, 1, 1, 16).astype(np.float32)
    const_b = helper.make_tensor("const_b", TensorProto.FLOAT, (1, 1, 1, 16), bias_val.flatten())
    const_b_node = helper.make_node("Constant", [], ["const_b"], value=const_b, name="const_b")

    conv_node = helper.make_node("Conv", ["x", "W"], ["X"], name="conv", pads=[0, 0, 0, 0])
    trans_in = helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 2, 3, 1], name="trans_1")
    add_node = helper.make_node("Add", ["Y", "const_b"], ["Z"], name="add")
    trans_out = helper.make_node("Transpose", ["Z"], ["res"], perm=[0, 3, 1, 2], name="trans_2")

    graph = helper.make_graph(
        [const_b_node, conv_node, trans_in, add_node, trans_out],
        "transpose-add-test-with-conv-1",
        [helper.make_tensor_value_info("x", TensorProto.FLOAT, (1, 5, 3, 3)),
         helper.make_tensor_value_info("W", TensorProto.FLOAT, (16, 5, 3, 3))],
        [helper.make_tensor_value_info("res", TensorProto.FLOAT, (1, 16, 1, 1))],
    )

    model_proto = helper.make_model(graph, producer_name="onnx-tests")
    feed = {"x": np.random.randn(1, 5, 3, 3).astype(np.float32),
            "W": np.random.randn(16, 5, 3, 3).astype(np.float32)}
    self.run_transpose_compare(["res"], feed, model_proto, remaining_transpose_num=0)
510+
511+
def test_transpose_add_with_conv_2(self):
    """Bias constant that is NOT 1-D, so it cannot be merged into the Conv;
    the Add handler should still remove the Transposes around the Add."""
    bias_val = np.random.randn(1, 3, 3, 1).astype(np.float32)
    const_b = helper.make_tensor("const_b", TensorProto.FLOAT, (1, 3, 3, 1), bias_val.flatten())
    const_b_node = helper.make_node("Constant", [], ["const_b"], value=const_b, name="const_b")

    conv_node = helper.make_node("Conv", ["x", "W"], ["X"], name="conv", pads=[0, 0, 0, 0])
    trans_in = helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 2, 3, 1], name="trans_1")
    add_node = helper.make_node("Add", ["Y", "const_b"], ["Z"], name="add")
    trans_out = helper.make_node("Transpose", ["Z"], ["res"], perm=[0, 3, 1, 2], name="trans_2")

    graph = helper.make_graph(
        [const_b_node, conv_node, trans_in, add_node, trans_out],
        "transpose-add-test-with-conv-2",
        [helper.make_tensor_value_info("x", TensorProto.FLOAT, (1, 1, 5, 5)),
         helper.make_tensor_value_info("W", TensorProto.FLOAT, (1, 1, 3, 3))],
        [helper.make_tensor_value_info("res", TensorProto.FLOAT, (1, 1, 3, 3))],
    )

    model_proto = helper.make_model(graph, producer_name="onnx-tests")
    feed = {"x": np.random.randn(1, 1, 5, 5).astype(np.float32),
            "W": np.random.randn(1, 1, 3, 3).astype(np.float32)}
    self.run_transpose_compare(["res"], feed, model_proto, remaining_transpose_num=0)
535+
536+
def test_transpose_pad(self):
    """Transpose -> Pad -> Transpose: the optimizer should push the
    transpose through Pad and remove both of them."""
    trans_in = helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 2, 3, 1], name="trans_1")
    pad_node = helper.make_node("Pad", ["Y"], ["Z"], pads=[1, 0, 1, 3, 0, 0, 2, 0], name="pad")
    trans_out = helper.make_node("Transpose", ["Z"], ["res"], perm=[0, 3, 1, 2], name="trans_2")

    graph = helper.make_graph(
        [trans_in, pad_node, trans_out],
        "transpose-pad-test",
        [helper.make_tensor_value_info("X", TensorProto.FLOAT, (1, 3, 4, 5))],
        [helper.make_tensor_value_info("res", TensorProto.FLOAT, (2, 6, 4, 8))],
    )

    model_proto = helper.make_model(graph, producer_name="onnx-tests")
    feed = {"X": np.random.randn(1, 3, 4, 5).astype(np.float32)}
    self.run_transpose_compare(["res"], feed, model_proto, remaining_transpose_num=0)
551+
552+
def test_transpose_reducemean(self):
    """Transpose -> ReduceMean(keepdims) -> Transpose: both Transposes
    should be eliminated (axes are remapped through the permutation)."""
    trans_in = helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 2, 3, 1], name="trans_1")
    reduce_node = helper.make_node("ReduceMean", ["Y"], ["Z"], axes=[1, 2], keepdims=1, name="reducemean")
    trans_out = helper.make_node("Transpose", ["Z"], ["res"], perm=[0, 3, 1, 2], name="trans_2")

    graph = helper.make_graph(
        [trans_in, reduce_node, trans_out],
        "transpose-reducemean-test",
        [helper.make_tensor_value_info("X", TensorProto.FLOAT, (1, 3, 4, 5))],
        [helper.make_tensor_value_info("res", TensorProto.FLOAT, (1, 3, 1, 1))],
    )

    model_proto = helper.make_model(graph, producer_name="onnx-tests")
    feed = {"X": np.random.randn(1, 3, 4, 5).astype(np.float32)}
    self.run_transpose_compare(["res"], feed, model_proto, remaining_transpose_num=0)
567+
397568
def test_trans_output_as_graph_outputs(self):
398569
"""
399570
If transpose's output is graph's output, don't optimize it.

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 10 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -370,16 +370,23 @@ def _add_handler(self, trans, node):
370370
if t_p.type in ("Conv", "ConvTranspose") and len(t_p.input) == 2:
371371
# if Conv or ConvTranspose's bias input is not set, then we set, otherwise, we don't set
372372
# todo: maybe we can add already set bias with the input??? try later
373+
374+
target_node = node.inputs[1]
375+
numpy_val = target_node.get_tensor_value(as_list=False)
376+
# Optional 1D bias to be added to the convolution, has size of M
377+
if len(numpy_val.shape) - numpy_val.shape.count(1) > 1:
378+
return self._handle_node_having_branches(node)
379+
transposed_val = np.transpose(numpy_val, (0, 3, 1, 2))
380+
target_node.set_tensor_value(transposed_val)
381+
373382
conv_inputs = [t_p.input[0], t_p.input[1], node.input[1]]
374383
conv_node = self._g.make_node(t_p.type, conv_inputs, attr=t_p.attr_onnx)
375384
ops = self._g.get_nodes()
376385
trans.input[0] = utils.port_name(conv_node.name)
377386
self._g.replace_all_inputs(ops, node.output[0], trans.output[0])
378-
379387
self._g.remove_node(t_p.name)
380388
self._g.remove_node(node.name)
381389
return True
382-
return False
383390
return self._handle_node_having_branches(node)
384391

385392
def _transpose_handler(self, trans, node):
@@ -398,31 +405,7 @@ def _transpose_handler(self, trans, node):
398405
return False
399406

400407
def _maxmin_handler(self, trans, node):
    # Max/Min are element-wise multi-input ops; after this commit the
    # special-case constant handling is gone and the generic branch
    # handler decides whether the transpose can be pushed through.
    return self._handle_node_having_branches(node)
426409

427410
def _mul_handler(self, trans, node):
428411
multiplier_input_id = None

0 commit comments

Comments
 (0)