Skip to content

Commit f445505

Browse files
authored
Merge pull request #358 from mindest/dev_sign_op9
implement Sign for opset9
2 parents 52bf5e4 + 3e5afc0 commit f445505

File tree

1 file changed

+11
-2
lines changed

1 file changed

+11
-2
lines changed

tf2onnx/tfonnx.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -670,7 +670,7 @@ def cast_op(ctx, node, name, args):
670670
return node
671671

672672

673-
def sign_op(ctx, node, name, args):
673+
def sign_op4(ctx, node, name, args):
674674
"""Sign op."""
675675
# T sign = Sign(T Input)
676676
nodes = []
@@ -690,6 +690,14 @@ def sign_op(ctx, node, name, args):
690690
return nodes
691691

692692

693+
def sign_op9(ctx, node, name, args):
694+
node_dtype = ctx.get_dtype(node.output[0])
695+
utils.make_sure(node_dtype, "Dtype of {} is None".format(node.name))
696+
if node_dtype in [onnx_pb.TensorProto.BOOL, onnx_pb.TensorProto.COMPLEX64, onnx_pb.TensorProto.COMPLEX128]:
697+
raise ValueError("dtype " + str(node_dtype) + " is not supported in onnx for now")
698+
return node
699+
700+
693701
def biasadd_op(ctx, node, name, args):
694702
# T output = BiasAdd(T value, T bias, @string data_format)
695703
# T output = BiasAddV1(T value, T bias)
@@ -1793,7 +1801,7 @@ def where_op(ctx, node, name, args):
17931801
"Pack": (pack_op, []),
17941802
"Unpack": (unpack_op, []),
17951803
"Erf": (erf_op, []),
1796-
"Sign": (sign_op, []),
1804+
"Sign": (sign_op4, []),
17971805
"ZerosLike": (zeroslike_op, []),
17981806
}
17991807

@@ -1868,6 +1876,7 @@ def where_op(ctx, node, name, args):
18681876
"Less": (logical_compare_op, []),
18691877
"ResizeBilinear": (upsample_op9, ["Upsample", "linear"]),
18701878
"ResizeNearestNeighbor": (upsample_op9, ["Upsample", "nearest"]),
1879+
"Sign": (sign_op9, []),
18711880
"Sinh": (direct_op, []),
18721881
"Where": (where_op, []),
18731882
}

0 commit comments

Comments
 (0)