Skip to content

Commit 2f577d5

Browse files
authored
Merge pull request #435 from chinhuang007/add-tf-onnx-ops
add direct ops from onnx-tf
2 parents d43b585 + 2bcaa5f commit 2f577d5

File tree

2 files changed

+23
-0
lines changed

2 files changed

+23
-0
lines changed

tests/test_backend.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1902,6 +1902,26 @@ def test_isnan(self):
19021902
self._run_test_case([_OUTPUT], {_INPUT: x_val})
19031903
tf.reset_default_graph()
19041904

1905+
def test_ceil(self):
1906+
x_val = np.array([-1.5, 1.2], dtype=np.float32)
1907+
x = tf.placeholder(tf.float32, [2], name=_TFINPUT)
1908+
x_ = tf.ceil(x)
1909+
_ = tf.identity(x_, name=_TFOUTPUT)
1910+
self._run_test_case([_OUTPUT], {_INPUT: x_val})
1911+
1912+
def test_softplus(self):
1913+
x_val = np.array([-1, 0, 1], dtype=np.float32)
1914+
x = tf.placeholder(tf.float32, [3], name=_TFINPUT)
1915+
x_ = tf.math.softplus(x)
1916+
_ = tf.identity(x_, name=_TFOUTPUT)
1917+
self._run_test_case([_OUTPUT], {_INPUT: x_val})
1918+
1919+
def test_softsign(self):
1920+
x_val = np.array([-1, 0, 1], dtype=np.float32)
1921+
x = tf.placeholder(tf.float32, [3], name=_TFINPUT)
1922+
x_ = tf.math.softsign(x)
1923+
_ = tf.identity(x_, name=_TFOUTPUT)
1924+
self._run_test_case([_OUTPUT], {_INPUT: x_val})
19051925

19061926
if __name__ == '__main__':
19071927
unittest_main()

tf2onnx/tfonnx.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1756,6 +1756,7 @@ def where_op(ctx, node, name, args):
17561756
"BiasAdd": (biasadd_op, []),
17571757
"BiasAddV1": (biasadd_op, []),
17581758
"Cast": (cast_op, []),
1759+
"Ceil": (direct_op, []),
17591760
"CheckNumerics": (identity_op, ["Identity"]),
17601761
"Concat": (concat_op, ["Concat"]),
17611762
"ConcatV2": (concatv2_op, ["Concat"]),
@@ -1828,6 +1829,8 @@ def where_op(ctx, node, name, args):
18281829
"Square": (square_op, []),
18291830
"SquaredDifference": (squareddifference_op, []),
18301831
"Softmax": (softmax_op, ["Softmax"]),
1832+
"Softplus": (direct_op, []),
1833+
"Softsign": (direct_op, []),
18311834
"StopGradient": (identity_op, ["Identity"]),
18321835
"StridedSlice": (stridedslice_op, []),
18331836
"Sub": (broadcast_op, []),

0 commit comments

Comments
 (0)