Skip to content

Commit c9d6295

Browse files
committed
resolve requested changes
1 parent f6f8e17 commit c9d6295

File tree

5 files changed

+9
-13
lines changed

5 files changed

+9
-13
lines changed

tests/common.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@
3333
"validate_const_node",
3434
"group_nodes_by_type",
3535
"test_ms_domain",
36-
"check_node_domain"
36+
"check_node_domain",
37+
"check_thresholded_relu_count"
3738
]
3839

3940

@@ -305,6 +306,10 @@ def check_gru_count(graph, expected_count):
305306
return check_op_count(graph, "GRU", expected_count)
306307

307308

309+
def check_thresholded_relu_count(graph, expected_count):
310+
return check_op_count(graph, "ThresholdedRelu", expected_count)
311+
312+
308313
_MAX_MS_OPSET_VERSION = 1
309314

310315

tests/test_backend.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2148,7 +2148,8 @@ def test_thresholded_relu(self):
21482148
t = tf.keras.layers.ThresholdedReLU(theta=theta_val)
21492149
x_ = t.call(x)
21502150
_ = tf.identity(x_, name=_TFOUTPUT)
2151-
self._run_test_case([_OUTPUT], {_INPUT: x_val})
2151+
self._run_test_case([_OUTPUT], {_INPUT: x_val},
2152+
graph_validator=lambda g: check_thresholded_relu_count(g, 1))
21522153
tf.reset_default_graph()
21532154

21542155

tf2onnx/onnx_opset/math.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -399,10 +399,3 @@ def version_7(cls, ctx, node, **kwargs):
399399
ctx.remove_node(node.name)
400400
ctx.make_node(op_type="Sub", inputs=[node.input[0], mul.output[0]],
401401
name=node.name, outputs=node.output, shapes=shapes, dtypes=dtypes)
402-
403-
404-
@tf_op("ThresholdedReLU", onnx_op="ThresholdedRelu")
405-
class ThresholdedRelu:
406-
@classmethod
407-
def version_10(cls, ctx, node, **kwargs):
408-
pass

tf2onnx/rewriter/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from tf2onnx.rewriter.cond_rewriter import rewrite_cond
1010
from tf2onnx.rewriter.random_uniform import rewrite_random_uniform, rewrite_random_uniform_fold_const
1111
from tf2onnx.rewriter.leakyrelu_rewriter import rewrite_leakyrelu
12-
from tf2onnx.rewriter.thresholdedrelu_rewriter import rewrite_thresholded_relu
12+
from tf2onnx.rewriter.thresholded_relu_rewriter import rewrite_thresholded_relu
1313
from tf2onnx.rewriter.rnn import rewrite_single_direction_lstm, rewrite_bi_direction_lstm, \
1414
rewrite_single_direction_gru, rewrite_bi_direction_gru, \
1515
rewrite_custom_rnn_cell, rewrite_generic_loop

tf2onnx/rewriter/thresholdedrelu_rewriter.py renamed to tf2onnx/rewriter/thresholded_relu_rewriter.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,6 @@ def rewrite_thresholded_relu(g, ops):
4040
mul_input_edge_name = _find_edge_name_between_nodes(mul_input_node, mul_node)
4141
if greater_input_edge_name == mul_input_edge_name:
4242
theta = match.get_op('theta').get_tensor_value()
43-
# check disabled for now, tf requires theta to be non-negative, while onnx does not
44-
# if theta < 0:
45-
# continue
4643
thresholded_relu = g.make_node("ThresholdedRelu", inputs=[mul_input_edge_name], attr={"alpha": theta},
4744
shapes=[g.get_shape(mul_node.output[0])],
4845
dtypes=[g.get_dtype(mul_node.output[0])])

0 commit comments

Comments
 (0)