Skip to content

Commit 40dd461

Browse files
Merge pull request #897 from RandySheriffH/rashuai/opset12
opset 12 support
2 parents 8e41270 + fa18093 commit 40dd461

File tree

5 files changed

+72
-4
lines changed

5 files changed

+72
-4
lines changed

ci_build/azure_pipelines/pylint.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ jobs:
88
steps:
99
- bash: |
1010
set -ex
11-
pip install pylint
11+
pip install pylint==2.4.4
1212
pip freeze
1313
pylint --rcfile=tools/pylintrc --ignore=version.py --disable=cyclic-import tf2onnx tests/*.py tools -j 0
1414
displayName: 'Pylint'

tests/test_backend.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3130,6 +3130,40 @@ def func(X, K):
31303130
k_val = np.array(raw_k).astype(np.int32)
31313131
self._run_test_case(func, [_OUTPUT, _OUTPUT1], {_INPUT: x_val, _INPUT1: k_val})
31323132

3133+
@check_opset_min_version(12)
3134+
def test_inverse(self):
3135+
x_val = np.random.random([5, 5]).astype(np.float32)
3136+
def func(x):
3137+
return tf.linalg.inv(x, name=_TFOUTPUT)
3138+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
3139+
3140+
@check_opset_min_version(12)
3141+
def test_squared_distance(self):
3142+
x_val = np.random.random([4, 5]).astype(np.float32)
3143+
y_val = np.random.random([4, 5]).astype(np.float32)
3144+
def func(x, y):
3145+
return tf.math.squared_difference(x, y, name=_TFOUTPUT)
3146+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val})
3147+
3148+
@check_opset_min_version(12)
3149+
@check_tf_min_version("2.1")
3150+
def test_einsum(self):
3151+
x_val = np.random.random([10]).astype(np.float32)
3152+
y_val = np.random.random([10]).astype(np.float32)
3153+
def func(x, y):
3154+
ret = tf.einsum("i,j->ij", x, y)
3155+
return tf.identity(ret, name=_TFOUTPUT)
3156+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val})
3157+
3158+
@check_opset_min_version(7)
3159+
def test_compare(self):
3160+
x_val = np.random.random([10, 20]).astype(np.float32)
3161+
y_val = np.random.random([10, 20]).astype(np.float32)
3162+
def func(x, y):
3163+
return tf.math.less_equal(x, y, name=_TFOUTPUT), \
3164+
tf.math.greater_equal(x, y, name=_TFOUTPUT1)
3165+
self._run_test_case(func, [_OUTPUT, _OUTPUT1], {_INPUT: x_val, _INPUT1: y_val})
3166+
31333167

31343168
if __name__ == '__main__':
31353169
unittest_main()

tf2onnx/custom_opsets/ms.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,3 +103,11 @@ def version_11(cls, ctx, node, **kwargs):
103103
ctx.insert_new_node_on_input(node, "Transpose", node.input[0], perm=constants.NHWC_TO_NCHW)
104104
ctx.insert_new_node_on_output("Transpose", node.output[0], node.name + '_transposed',
105105
None, perm=constants.NCHW_TO_NHWC)
106+
107+
@tf_op("MatrixInverse", domain=constants.MICROSOFT_DOMAIN, onnx_op="Inverse")
108+
class Inverse:
109+
@classmethod
110+
def version_12(cls, ctx, node, **kwargs):
111+
utils.make_sure(node.get_attr('adjoint').i == 0, "adjoint must be false")
112+
del node.attr["adjoint"]
113+
node.domain = constants.MICROSOFT_DOMAIN

tf2onnx/onnx_opset/logical.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,14 +110,17 @@ def version_7(cls, ctx, node, **kwargs):
110110
target_dtype = TensorProto.FLOAT
111111
_add_cast_to_inputs(ctx, node, supported_dtypes, target_dtype)
112112

113-
114-
@tf_op("GreaterEqual", onnx_op="Less")
115-
@tf_op("LessEqual", onnx_op="Greater")
113+
@tf_op(["GreaterEqual", "LessEqual"])
116114
class GreaterLessEqual:
117115
@classmethod
118116
def version_7(cls, ctx, node, **kwargs):
119117
GreaterLess.version_7(ctx, node, **kwargs)
120118
output_name = node.output[0]
119+
node.op.op_type = "Less" if node.op.op_type == "GreaterEqual" else "Greater"
121120
new_node = ctx.insert_new_node_on_output("Not", output_name, name=utils.make_name(node.name))
122121
ctx.copy_shape(output_name, new_node.output[0])
123122
ctx.set_dtype(new_node.output[0], ctx.get_dtype(output_name))
123+
124+
@classmethod
125+
def version_12(cls, ctx, node, **kwargs):
126+
node.op.op_type = "GreaterOrEqual" if node.op.op_type == "GreaterEqual" else "LessOrEqual"

tf2onnx/onnx_opset/math.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -544,3 +544,26 @@ def version_11(cls, ctx, node, **kwargs):
544544
cast_back_node.set_attr("to", dtypes[0])
545545
ctx.set_dtype(cast_back_node.output[0], dtypes[0])
546546
ctx.copy_shape(node.name, cast_back_node.output[0])
547+
548+
549+
@tf_op("SquaredDistance", onnx_op="MeanSquaredDistance")
550+
class SquaredDistance:
551+
@classmethod
552+
def version_12(cls, ctx, node, **kwargs):
553+
node.attr["reduction"] = "none"
554+
555+
556+
@tf_op("Einsum")
557+
class Einsum:
558+
@classmethod
559+
def version_12(cls, ctx, node, **kwargs):
560+
del node.attr["N"]
561+
562+
563+
@tf_op("MatrixInverse", onnx_op="Inverse")
564+
class Inverse:
565+
@classmethod
566+
def version_12(cls, ctx, node, **kwargs):
567+
utils.make_sure(node.get_attr('adjoint').i == 0, "adjoint must be false")
568+
del node.attr["adjoint"]
569+
node.domain = constants.MICROSOFT_DOMAIN

0 commit comments

Comments
 (0)