Skip to content

Commit 8a3164f

Browse files
committed
opset 12 support
1 parent 212e119 commit 8a3164f

File tree

2 files changed

+47
-1
lines changed

2 files changed

+47
-1
lines changed

tests/test_backend.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3130,6 +3130,29 @@ def func(X, K):
31303130
k_val = np.array(raw_k).astype(np.int32)
31313131
self._run_test_case(func, [_OUTPUT, _OUTPUT1], {_INPUT: x_val, _INPUT1: k_val})
31323132

3133+
@check_opset_min_version(12)
3134+
def test_inverse(self):
3135+
x_val = np.random.random([5, 5]).astype(np.float32)
3136+
def func(x):
3137+
return tf.linalg.inv(x, name=_TFOUTPUT)
3138+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
3139+
3140+
@check_opset_min_version(12)
3141+
def test_less_or_equal(self):
3142+
x_val = np.random.random([4, 5]).astype(np.float32)
3143+
y_val = np.random.random([4, 5]).astype(np.float32)
3144+
def func(x, y):
3145+
return tf.math.less_equal(x, y, name=_TFOUTPUT)
3146+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val})
3147+
3148+
@check_opset_min_version(12)
3149+
def test_squared_distance(self):
3150+
x_val = np.random.random([4, 5]).astype(np.float32)
3151+
y_val = np.random.random([4, 5]).astype(np.float32)
3152+
def func(x, y):
3153+
return tf.math.squared_difference(x, y, name=_TFOUTPUT)
3154+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val})
3155+
31333156

31343157
if __name__ == '__main__':
31353158
unittest_main()

tf2onnx/onnx_opset/math.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
# pylint: disable=unused-argument,missing-docstring
2424

25-
@tf_op(["Add", "AddV2", "Div", "Mul", "Sub"])
25+
@tf_op(["Add", "AddV2", "Div", "Mul", "Sub", "LessOrEqual"])
2626
class BroadcastOp(common.BroadcastOp):
2727
pass
2828

@@ -544,3 +544,26 @@ def version_11(cls, ctx, node, **kwargs):
544544
cast_back_node.set_attr("to", dtypes[0])
545545
ctx.set_dtype(cast_back_node.output[0], dtypes[0])
546546
ctx.copy_shape(node.name, cast_back_node.output[0])
547+
548+
@tf_op("MatrixInverse")
549+
class Inverse:
550+
551+
@classmethod
552+
def version_12(cls, ctx, node, **kwargs):
553+
utils.make_sure(node.get_attr('adjoint').i == 0, "adjoint must be false")
554+
shapes = node.output_shapes
555+
dtypes = node.output_dtypes
556+
ctx.remove_node(node.name)
557+
ctx.make_node("Inverse", inputs=node.input, outputs=node.output, name=node.name,
558+
shapes=shapes, dtypes=dtypes)
559+
560+
@tf_op("SquaredDistance")
561+
class SquaredDistance:
562+
563+
@classmethod
564+
def version_12(cls, ctx, node, **kwargs):
565+
shapes = node.output_shapes
566+
dtypes = node.output_dtypes
567+
ctx.remove_node(node.name)
568+
ctx.make_node("MeanSquaredDistance", inputs=node.input, outputs=node.output, name=node.name,
569+
shapes=shapes, dtypes=dtypes, attr={"reduction": "none"})

0 commit comments

Comments
 (0)