Skip to content

Commit 83cf13a

Browse files
authored
Merge pull request #492 from mindest/thresholdedrelu_opset10
implement ThresholdedRelu for opset10
2 parents 395f1a9 + 3169bc1 commit 83cf13a

File tree

5 files changed

+71
-2
lines changed

5 files changed

+71
-2
lines changed

tests/common.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@
3333
"validate_const_node",
3434
"group_nodes_by_type",
3535
"test_ms_domain",
36-
"check_node_domain"
36+
"check_node_domain",
37+
"check_op_count"
3738
]
3839

3940

tests/test_backend.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2167,5 +2167,20 @@ def test_conv1d_5(self):
21672167
w = np.array([3., 3., 3.], dtype=np.float32).reshape(3, 1, 1)
21682168
self._conv1d_test(x_val, w)
21692169

2170+
    @check_opset_min_version(10, "ThresholdedRelu")
    def test_thresholded_relu(self):
        """Verify the Greater->Cast->Mul subgraph emitted by Keras ThresholdedReLU
        is rewritten to exactly one ONNX ThresholdedRelu node (opset >= 10)."""
        # tf.keras.layers.ThresholdedReLU only supports `float32` for x
        # Values straddle every theta below so each run exercises both sides of the threshold.
        x_val = np.array([0.0, 1.0, -1.0, 2.0, -2.0, 0.5, -0.5, 1.5, -1.5], dtype=np.float32).reshape((3, 3))
        # Cover negative, zero, and positive thresholds.
        theta_vals = [-1.0, -0.5, 0.0, 0.5, 1.0]
        for theta_val in theta_vals:
            x = tf.placeholder(x_val.dtype, x_val.shape, name=_TFINPUT)
            t = tf.keras.layers.ThresholdedReLU(theta=theta_val)
            # NOTE(review): invokes the layer's call() directly rather than t(x);
            # presumably fine for a weightless layer — confirm against Keras docs.
            x_ = t.call(x)
            _ = tf.identity(x_, name=_TFOUTPUT)
            # graph_validator asserts the rewriter collapsed the pattern into
            # exactly one ThresholdedRelu op in the converted graph.
            self._run_test_case([_OUTPUT], {_INPUT: x_val},
                                graph_validator=lambda g: check_op_count(g, "ThresholdedRelu", 1))
            # Fresh default graph for the next theta iteration (TF1 graph mode).
            tf.reset_default_graph()
2183+
2184+
21702185
if __name__ == '__main__':
21712186
unittest_main()

tf2onnx/rewriter/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from tf2onnx.rewriter.cond_rewriter import rewrite_cond
1010
from tf2onnx.rewriter.random_uniform import rewrite_random_uniform, rewrite_random_uniform_fold_const
1111
from tf2onnx.rewriter.leakyrelu_rewriter import rewrite_leakyrelu
12+
from tf2onnx.rewriter.thresholded_relu_rewriter import rewrite_thresholded_relu
1213
from tf2onnx.rewriter.rnn import rewrite_single_direction_lstm, rewrite_bi_direction_lstm, \
1314
rewrite_single_direction_gru, rewrite_bi_direction_gru, \
1415
rewrite_custom_rnn_cell, rewrite_generic_loop
@@ -18,6 +19,7 @@
1819
"rewrite_random_uniform",
1920
"rewrite_random_uniform_fold_const",
2021
"rewrite_leakyrelu",
22+
"rewrite_thresholded_relu",
2123
"rewrite_single_direction_lstm",
2224
"rewrite_bi_direction_lstm",
2325
"rewrite_single_direction_gru",
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# Copyright (c) Microsoft Corporation. All rights reserved.
2+
# Licensed under the MIT license.
3+
4+
"""
5+
tf2onnx.rewriter - rewrite tensorflow subgraph to onnx ThresholdedRelu op
6+
"""
7+
8+
from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
9+
from tf2onnx.rewriter.leakyrelu_rewriter import _find_edge_name_between_nodes
10+
11+
12+
# pylint: disable=missing-docstring
13+
14+
15+
def rewrite_thresholded_relu(g, ops):
    """Fuse ``x * cast(x > theta)`` subgraphs into a single ThresholdedRelu op.

    The Keras ThresholdedReLU layer lowers to a Greater -> Cast -> Mul chain.
    When the Greater and the Mul read the same input tensor, that chain is
    equivalent to ONNX ThresholdedRelu with ``alpha = theta``, which exists
    from opset 10 onward.

    :param g: the graph being converted; provides node construction and
        shape/dtype lookup helpers.
    :param ops: mutable list of nodes; matched subgraphs are removed and the
        fused node appended in place.
    :return: the (possibly mutated) ``ops`` list.
    """
    # ThresholdedRelu was introduced in ONNX opset 10 — nothing to do before that.
    if g.opset < 10:
        return ops

    pattern = \
        OpTypePattern('Mul', name='mul', inputs=[
            OpTypePattern('Cast', name='cast', inputs=[
                OpTypePattern('Greater', name='greater', inputs=[
                    OpTypePattern('*', name='greater_input'),
                    OpTypePattern('Const', name='theta')
                ])
            ]),
            OpTypePattern('*', name='mul_input')
        ])
    # Materialize matches up front: the loop below mutates `ops` while running.
    match_results = list(GraphMatcher(pattern, allow_reorder=True).match_ops(ops))

    for match in match_results:
        greater = match.get_op('greater')
        cast = match.get_op('cast')
        mul = match.get_op("mul")

        # The fusion is only valid when Greater and Mul consume the same tensor;
        # compare the actual edge names feeding each node.
        edge_into_greater = _find_edge_name_between_nodes(match.get_op('greater_input'), greater)
        edge_into_mul = _find_edge_name_between_nodes(match.get_op('mul_input'), mul)
        if edge_into_greater != edge_into_mul:
            continue

        alpha = match.get_op('theta').get_tensor_value()
        fused = g.make_node("ThresholdedRelu", inputs=[edge_into_mul], attr={"alpha": alpha},
                            shapes=[g.get_shape(mul.output[0])],
                            dtypes=[g.get_dtype(mul.output[0])])
        # Drop the three fused nodes, splice in the replacement, and rewire consumers.
        for stale in (greater, cast, mul):
            ops.remove(stale)
        ops.append(fused)
        g.replace_all_inputs(ops, mul.output[0], fused.output[0])
    return ops

tf2onnx/tfonnx.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -772,7 +772,7 @@ def compat_handler(ctx, node, **kwargs):
772772
rewriters = [rewrite_transpose, rewrite_flatten,
773773
rewrite_random_uniform, rewrite_random_uniform_fold_const,
774774
rewrite_random_normal, rewrite_dropout,
775-
rewrite_leakyrelu, rewrite_conv2d_with_pad,
775+
rewrite_leakyrelu, rewrite_thresholded_relu, rewrite_conv2d_with_pad,
776776
rewrite_single_direction_lstm, rewrite_bi_direction_lstm,
777777
rewrite_single_direction_gru, rewrite_bi_direction_gru,
778778
rewrite_custom_rnn_cell, rewrite_generic_loop, rewrite_cond

0 commit comments

Comments
 (0)