Skip to content

Commit 4cf284e

Browse files
committed
implement ThresholdedRelu for opset 10
1 parent a4607b3 commit 4cf284e

File tree

5 files changed

+78
-1
lines changed

5 files changed

+78
-1
lines changed

tests/test_backend.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2138,5 +2138,19 @@ def test_non_max_suppression(self):
21382138
self._run_test_case([_OUTPUT, _OUTPUT1], {_INPUT: boxes_val, _INPUT1: scores_val})
21392139

21402140

2141+
@check_opset_min_version(10, "ThresholdedRelu")
def test_thresholded_relu(self):
    """Verify conversion of tf.keras ThresholdedReLU for several theta values."""
    # tf.keras.layers.ThresholdedReLU only supports `float32` for x
    x_val = np.array([0.0, 1.0, -1.0, 2.0, -2.0, 0.5, -0.5, 1.5, -1.5], dtype=np.float32).reshape((3, 3))
    # Negative thetas are exercised too: ONNX allows them even though TF docs
    # describe theta as non-negative.
    for theta in [-1.0, -0.5, 0.0, 0.5, 1.0]:
        x = tf.placeholder(x_val.dtype, x_val.shape, name=_TFINPUT)
        layer = tf.keras.layers.ThresholdedReLU(theta=theta)
        _ = tf.identity(layer.call(x), name=_TFOUTPUT)
        self._run_test_case([_OUTPUT], {_INPUT: x_val})
        # Each theta gets a fresh graph so placeholders don't collide.
        tf.reset_default_graph()
2153+
2154+
21412155
if __name__ == '__main__':
21422156
unittest_main()

tf2onnx/onnx_opset/math.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -399,3 +399,10 @@ def version_7(cls, ctx, node, **kwargs):
399399
ctx.remove_node(node.name)
400400
ctx.make_node(op_type="Sub", inputs=[node.input[0], mul.output[0]],
401401
name=node.name, outputs=node.output, shapes=shapes, dtypes=dtypes)
402+
403+
404+
@tf_op("ThresholdedReLU", onnx_op="ThresholdedRelu")
class ThresholdedRelu:
    """Handler for TF ThresholdedReLU -> ONNX ThresholdedRelu (opset 10+)."""

    @classmethod
    def version_10(cls, ctx, node, **kwargs):
        # The op is renamed via `onnx_op` in the decorator; no input or
        # attribute rewriting happens here.
        pass

tf2onnx/rewriter/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from tf2onnx.rewriter.cond_rewriter import rewrite_cond
1010
from tf2onnx.rewriter.random_uniform import rewrite_random_uniform, rewrite_random_uniform_fold_const
1111
from tf2onnx.rewriter.leakyrelu_rewriter import rewrite_leakyrelu
12+
from tf2onnx.rewriter.thresholdedrelu_rewriter import rewrite_thresholded_relu
1213
from tf2onnx.rewriter.rnn import rewrite_single_direction_lstm, rewrite_bi_direction_lstm, \
1314
rewrite_single_direction_gru, rewrite_bi_direction_gru, \
1415
rewrite_custom_rnn_cell, rewrite_generic_loop
@@ -18,6 +19,7 @@
1819
"rewrite_random_uniform",
1920
"rewrite_random_uniform_fold_const",
2021
"rewrite_leakyrelu",
22+
"rewrite_thresholded_relu",
2123
"rewrite_single_direction_lstm",
2224
"rewrite_bi_direction_lstm",
2325
"rewrite_single_direction_gru",
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# Copyright (c) Microsoft Corporation. All rights reserved.
2+
# Licensed under the MIT license.
3+
4+
"""
5+
tf2onnx.rewriter - rewrite tensorflow subgraph to onnx ThresholdedRelu op
6+
"""
7+
8+
from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
9+
from tf2onnx.rewriter.leakyrelu_rewriter import _find_edge_name_between_nodes
10+
11+
12+
# pylint: disable=missing-docstring
13+
14+
15+
def rewrite_thresholded_relu(g, ops):
    """Fuse a Greater -> Cast -> Mul subgraph into one ONNX ThresholdedRelu node.

    TF expresses thresholded relu as ``x * cast(x > theta)``; when the target
    opset supports ThresholdedRelu (>= 10), the three-node pattern is replaced
    by a single node with ``alpha = theta``.
    """
    if g.opset < 10:
        # ThresholdedRelu only exists from opset 10 onwards; leave graph as-is.
        return ops

    pattern = \
        OpTypePattern('Mul', name='mul', inputs=[
            OpTypePattern('Cast', name='cast', inputs=[
                OpTypePattern('Greater', name='greater', inputs=[
                    OpTypePattern('*', name='greater_input'),
                    OpTypePattern('Const', name='theta')
                ])
            ]),
            OpTypePattern('*', name='mul_input')
        ])
    matcher = GraphMatcher(pattern, allow_reorder=True)
    for match in list(matcher.match_ops(ops)):
        greater = match.get_op('greater')
        cast = match.get_op('cast')
        mul = match.get_op("mul")

        # The fusion is only valid when the tensor being compared against
        # theta and the tensor being scaled by the mask are the same value.
        edge_into_greater = _find_edge_name_between_nodes(match.get_op('greater_input'), greater)
        edge_into_mul = _find_edge_name_between_nodes(match.get_op('mul_input'), mul)
        if edge_into_greater != edge_into_mul:
            continue

        theta = match.get_op('theta').get_tensor_value()
        # check disabled for now, tf requires theta to be non-negative, while onnx does not
        # if theta < 0:
        #     continue
        fused = g.make_node("ThresholdedRelu", inputs=[edge_into_mul], attr={"alpha": theta},
                            shapes=[g.get_shape(mul.output[0])], dtypes=[g.get_dtype(mul.output[0])])
        for stale in (greater, cast, mul):
            ops.remove(stale)
        ops.append(fused)
        g.replace_all_inputs(ops, mul.output[0], fused.output[0])
    return ops
54+

tf2onnx/tfonnx.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -772,7 +772,7 @@ def compat_handler(ctx, node, **kwargs):
772772
rewriters = [rewrite_transpose, rewrite_flatten,
773773
rewrite_random_uniform, rewrite_random_uniform_fold_const,
774774
rewrite_random_normal, rewrite_dropout,
775-
rewrite_leakyrelu, rewrite_conv2d_with_pad,
775+
rewrite_leakyrelu, rewrite_thresholded_relu, rewrite_conv2d_with_pad,
776776
rewrite_single_direction_lstm, rewrite_bi_direction_lstm,
777777
rewrite_single_direction_gru, rewrite_bi_direction_gru,
778778
rewrite_custom_rnn_cell, rewrite_generic_loop, rewrite_cond

0 commit comments

Comments
 (0)