Skip to content

Commit c4c3ccd

Browse files
authored
Merge pull request onnx#420 from winnietsang/add-thresholdedRelu
Add opset 10 support for thresholded-relu
2 parents c17b845 + c8364c0 commit c4c3ccd

File tree

4 files changed

+21
-12
lines changed

4 files changed

+21
-12
lines changed

doc/support_status.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ ______
153153
|Tan|7|
154154
|Tanh|1, 6|
155155
|TfIdfVectorizer|N/A|
156-
|ThresholdedRelu|1|
156+
|ThresholdedRelu|1, 10|
157157
|Tile|1, 6|
158158
|TopK|1, 10|
159159
|Transpose|1|

onnx_tf/handlers/backend/thresholded_relu.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,16 @@
1010
class ThresholdedRelu(BackendHandler):
1111

1212
@classmethod
13-
def get_attrs_processor_param(cls):
14-
return {"rename": {"alpha": "theta"}}
13+
def _common(cls, node, **kwargs):
14+
x = kwargs["tensor_dict"][node.inputs[0]]
15+
alpha = node.attrs.get("alpha", 1.0)
16+
epsilon = 1e-5
17+
return [tf.nn.relu(x) - tf.nn.relu(tf.sign(alpha - x + epsilon) * x)]
1518

1619
@classmethod
1720
def version_1(cls, node, **kwargs):
18-
x = kwargs["tensor_dict"][node.inputs[0]]
19-
if "alpha" not in node.attrs.keys():
20-
warnings.warn("Provide an alpha value.", UserWarning)
21-
alpha = 1
22-
else:
23-
alpha = node.attrs["alpha"]
21+
return cls._common(node, **kwargs)
2422

25-
epsilon = 1e-5
26-
return [tf.nn.relu(x) - tf.nn.relu(tf.sign(alpha - x + epsilon) * x)]
23+
@classmethod
24+
def version_10(cls, node, **kwargs):
25+
return cls._common(node, **kwargs)

onnx_tf/opset_version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@
146146
'Tan': [7],
147147
'Tanh': [1, 6],
148148
'TfIdfVectorizer': [],
149-
'ThresholdedRelu': [1],
149+
'ThresholdedRelu': [1, 10],
150150
'Tile': [1, 6],
151151
'TopK': [1, 10],
152152
'Transpose': [1],

test/backend/test_node.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1071,6 +1071,16 @@ def test_tanh(self):
10711071
output = run_node(node_def, [x])
10721072
np.testing.assert_almost_equal(output["Y"], np.tanh(x), decimal=5)
10731073

1074+
def test_thresholded_relu(self):
1075+
alpha = 2.0
1076+
node_def = helper.make_node(
1077+
"ThresholdedRelu", ["X"], ["Y"], alpha=alpha)
1078+
x = self._get_rnd([10], -3.0, 3.0)
1079+
y = np.clip(x, alpha, np.inf)
1080+
y[y == alpha] = 0
1081+
output = run_node(node_def, [x])
1082+
np.testing.assert_almost_equal(output["Y"], y)
1083+
10741084
def test_tile(self):
10751085
if legacy_onnx_pre_ver(1, 2):
10761086
raise unittest.SkipTest(

0 commit comments

Comments
 (0)