Skip to content

Commit 9ace355

Browse files
committed
use onnx-op instead of removeing-adding node
1 parent e498134 commit 9ace355

File tree

1 file changed

+9
-15
lines changed

1 file changed

+9
-15
lines changed

tf2onnx/onnx_opset/math.py

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
# pylint: disable=unused-argument,missing-docstring
2424

25-
@tf_op(["Add", "AddV2", "Div", "Mul", "Sub", "LessOrEqual"])
25+
@tf_op(["Add", "AddV2", "Div", "Mul", "Sub"])
2626
class BroadcastOp(common.BroadcastOp):
2727
pass
2828

@@ -545,25 +545,19 @@ def version_11(cls, ctx, node, **kwargs):
545545
ctx.set_dtype(cast_back_node.output[0], dtypes[0])
546546
ctx.copy_shape(node.name, cast_back_node.output[0])
547547

548-
@tf_op("MatrixInverse")
549-
class Inverse:
550548

549+
@tf_op("MatrixInverse", onnx_op="Inverse")
550+
class Inverse:
551551
@classmethod
552552
def version_12(cls, ctx, node, **kwargs):
553553
utils.make_sure(node.get_attr('adjoint').i == 0, "adjoint must be false")
554-
shapes = node.output_shapes
555-
dtypes = node.output_dtypes
556-
ctx.remove_node(node.name)
557-
ctx.make_node("Inverse", inputs=node.input, outputs=node.output, name=node.name,
558-
domain=constants.MICROSOFT_DOMAIN, shapes=shapes, dtypes=dtypes)
554+
del node.attr["adjoint"]
555+
node.domain = constants.MICROSOFT_DOMAIN
559556

560-
@tf_op("SquaredDistance")
561-
class SquaredDistance:
562557

558+
@tf_op("SquaredDistance", onnx_op="MeanSquaredDistance")
559+
class SquaredDistance:
563560
@classmethod
564561
def version_12(cls, ctx, node, **kwargs):
565-
shapes = node.output_shapes
566-
dtypes = node.output_dtypes
567-
ctx.remove_node(node.name)
568-
ctx.make_node("MeanSquaredDistance", inputs=node.input, outputs=node.output, name=node.name,
569-
shapes=shapes, dtypes=dtypes, attr={"reduction": "none"})
562+
node.attr["reduction"] = "none"
563+

0 commit comments

Comments
 (0)