Skip to content

Commit 9ca2522

Browse files
Add PRelu to Transpose optimizer (#1630)
Signed-off-by: Tom Wildenhain <[email protected]> Co-authored-by: Guenther Schmuelling <[email protected]>
1 parent 760a555 commit 9ca2522

File tree

2 files changed

+29
-0
lines changed

2 files changed

+29
-0
lines changed

tests/test_optimizers.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,31 @@ def test_transpose_leaky_relu(self, shape, perm_input, perm_output):
211211
self.run_transpose_compare(["Z1"], {"X": np.random.randn(*shape).astype(np.float32)},
212212
model_proto, remaining_transpose_num=0)
213213

214+
@parameterized.expand([
215+
((2, 3, 4), [2, 0, 1], [1, 2, 0]),
216+
((2, 3, 4, 5), [0, 2, 3, 1], [0, 3, 1, 2]),
217+
((2, 3, 4, 5, 6), [0, 2, 3, 4, 1], [0, 4, 1, 2, 3]),
218+
])
219+
def test_transpose_with_prelu(self, input_shape, perm_input, perm_output):
220+
node1 = helper.make_node("Transpose", ["input_data1"], ["Y"], perm=perm_input, name="trans")
221+
node2 = helper.make_node("PRelu", ["Y", "input_data2"], ["Z"], name="add")
222+
node3 = helper.make_node("Transpose", ["Z"], ["res"], perm=perm_output, name="trans2")
223+
224+
graph = helper.make_graph(
225+
[node1, node2, node3],
226+
"transpose_with_shape",
227+
[helper.make_tensor_value_info("input_data1", TensorProto.FLOAT, input_shape),
228+
helper.make_tensor_value_info("input_data2", TensorProto.FLOAT, (input_shape[1],)),
229+
],
230+
[helper.make_tensor_value_info("res", TensorProto.FLOAT, input_shape)],
231+
)
232+
233+
model_proto = self.make_model(graph, producer_name="onnx-tests")
234+
feed_dict = {"input_data1": np.random.randn(*input_shape).astype(np.float32),
235+
"input_data2": np.random.randn(input_shape[1]).astype(np.float32),
236+
}
237+
self.run_transpose_compare(["res"], feed_dict, model_proto, remaining_transpose_num=0)
238+
214239
@parameterized.expand([
215240
((2, 3, 4), [2, 0, 1], [1, 2, 0]),
216241
((2, 3, 4, 5), [0, 2, 3, 1], [0, 3, 1, 2]),

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,7 @@ def _initialize_handlers(self):
208208
"Min": self._maxmin_handler,
209209
"Mul": self._mul_handler,
210210
"Pad": self._pad_handler,
211+
"PRelu": self._prelu_handler,
211212
"Reciprocal": self._simple_through_handler,
212213
"ReduceLogSum": self._reduce_handler,
213214
"ReduceLogSumExp": self._reduce_handler,
@@ -816,6 +817,9 @@ def permute_pads(pads):
816817
self._g.replace_input(node, node.input[1], new_pads.output[0], 1)
817818
return self._switch_transpose_and_node(node, trans)
818819

820+
def _prelu_handler(self, trans, node):
821+
return self._handle_node_having_branches(trans, node)
822+
819823
def _arg_min_max_handler(self, trans, node):
820824
axis = node.get_attr_value("axis", 0)
821825
node.set_attr("axes", [axis])

0 commit comments

Comments
 (0)