Skip to content

Commit 06d2fc9

Browse files
committed
Add Selu operator support
Add Selu operator support to partially address issue: #424 Set the version to 4 like most others.
1 parent 25973f3 commit 06d2fc9

File tree

2 files changed

+14
-0
lines changed

2 files changed

+14
-0
lines changed

tests/test_backend.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2266,5 +2266,13 @@ def test_maxpoolwithargmax(self):
22662266
self.logger.debug(str(p))
22672267
self._run_test_case([_OUTPUT, _OUTPUT1], {_INPUT: x_val})
22682268

2269+
@check_opset_min_version(10, "Selu")
2270+
def test_selu(self):
2271+
x_val = np.random.random_sample([3]).astype(np.float32)
2272+
x = tf.placeholder(x_val.dtype, x_val.shape, name=_TFINPUT)
2273+
y = tf.nn.selu(x)
2274+
_ = tf.identity(y, name=_TFOUTPUT)
2275+
self._run_test_case([_OUTPUT], {_INPUT: x_val})
2276+
22692277
if __name__ == '__main__':
22702278
unittest_main()

tf2onnx/onnx_opset/math.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -399,3 +399,9 @@ def version_7(cls, ctx, node, **kwargs):
399399
ctx.remove_node(node.name)
400400
ctx.make_node(op_type="Sub", inputs=[node.input[0], mul.output[0]],
401401
name=node.name, outputs=node.output, shapes=shapes, dtypes=dtypes)
402+
403+
@tf_op("Selu")
404+
class Selu:
405+
@classmethod
406+
def version_4(cls, ctx, node, **kwargs):
407+
pass

0 commit comments

Comments
 (0)