Skip to content

Commit 6bf60d1

Browse files
committed
support zerolike of bool
1 parent b39da4b commit 6bf60d1

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed

tf2onnx/onnx_opset/generator.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from onnx import onnx_pb, numpy_helper
1616
from tf2onnx import utils
1717
from tf2onnx.handler import tf_op
18+
from onnx.onnx_pb import TensorProto
1819

1920
logger = logging.getLogger(__name__)
2021

@@ -151,13 +152,13 @@ def version_7(cls, ctx, node, **kwargs):
151152
class ZerosLike:
152153
@classmethod
153154
def version_1(cls, ctx, node, **kwargs):
154-
# T output = ZerosLike(T x)
155-
# when params "dtype" used, tf will call another op "Fill" instead, so Cast is not needed here.
156-
input_dtype = ctx.get_dtype(node.input[0])
157-
node_name = utils.make_name("zero")
158-
const_zero = ctx.make_const(node_name, np.array(0).astype(utils.map_onnx_to_numpy_type(input_dtype)))
159155
shapes = node.output_shapes
160156
dtypes = node.output_dtypes
161157
ctx.remove_node(node.name)
162-
ctx.make_node(op_type="Mul", inputs=[node.input[0], const_zero.output[0]],
163-
name=node.name, outputs=node.output, shapes=shapes, dtypes=dtypes)
158+
casted_input = ctx.make_node("Cast", node.input, attr={'to': TensorProto.INT64})
159+
const_zero = ctx.make_const(utils.make_name("zero"), np.array(0).astype(np.int64))
160+
mul_node = ctx.make_node('Mul', inputs=[casted_input.output[0], const_zero.output[0]])
161+
casted_output = ctx.make_node("Cast", inputs=[mul_node.output[0]],
162+
attr={'to': dtypes[0]},
163+
name=node.name, outputs=node.output,
164+
shapes=shapes, dtypes=dtypes)

0 commit comments

Comments
 (0)