Skip to content

Commit ad2db07

Browse files
committed
move Inverse to ms.py
1 parent db2c48a commit ad2db07

File tree

2 files changed

+8
-9
lines changed

2 files changed

+8
-9
lines changed

tf2onnx/custom_opsets/ms.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,3 +103,11 @@ def version_11(cls, ctx, node, **kwargs):
103103
ctx.insert_new_node_on_input(node, "Transpose", node.input[0], perm=constants.NHWC_TO_NCHW)
104104
ctx.insert_new_node_on_output("Transpose", node.output[0], node.name + '_transposed',
105105
None, perm=constants.NCHW_TO_NHWC)
106+
107+
@tf_op("MatrixInverse", onnx_op="Inverse")
108+
class Inverse:
109+
@classmethod
110+
def version_12(cls, ctx, node, **kwargs):
111+
utils.make_sure(node.get_attr('adjoint').i == 0, "adjoint must be false")
112+
del node.attr["adjoint"]
113+
node.domain = constants.MICROSOFT_DOMAIN

tf2onnx/onnx_opset/math.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -546,15 +546,6 @@ def version_11(cls, ctx, node, **kwargs):
546546
ctx.copy_shape(node.name, cast_back_node.output[0])
547547

548548

549-
@tf_op("MatrixInverse", onnx_op="Inverse")
550-
class Inverse:
551-
@classmethod
552-
def version_12(cls, ctx, node, **kwargs):
553-
utils.make_sure(node.get_attr('adjoint').i == 0, "adjoint must be false")
554-
del node.attr["adjoint"]
555-
node.domain = constants.MICROSOFT_DOMAIN
556-
557-
558549
@tf_op("SquaredDistance", onnx_op="MeanSquaredDistance")
559550
class SquaredDistance:
560551
@classmethod

0 commit comments

Comments
 (0)