Skip to content

Commit 9bd265a

Browse files
Merge pull request #866 from RandySheriffH/rashuai/ZeroLikeBool
Rashuai/zero like bool
2 parents 9e46394 + 321714d commit 9bd265a

File tree

2 files changed

+16
-16
lines changed

2 files changed

+16
-16
lines changed

tests/test_backend.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2410,16 +2410,16 @@ def func(x):
24102410

24112411
@check_opset_min_version(7, "fill")
24122412
def test_zeros_like(self):
2413-
input_val = np.random.random_sample([10, 20]).astype(np.float32)
2414-
def func(x):
2415-
res = tf.zeros_like(x)
2416-
return tf.identity(res, name=_TFOUTPUT)
2417-
self._run_test_case(func, [_OUTPUT], {_INPUT: input_val})
2413+
input_x = np.random.random_sample([10, 20]).astype(np.float32)
2414+
input_y = np.array([20, 10]).astype(np.int64)
24182415

2419-
def func(x):
2420-
res = tf.zeros_like(x, dtype=tf.int32)
2421-
return tf.identity(res, name=_TFOUTPUT)
2422-
self._run_test_case(func, [_OUTPUT], {_INPUT: input_val})
2416+
def func(x, y):
2417+
z = tf.reshape(x, y)
2418+
return tf.zeros_like(z, name=_TFOUTPUT)
2419+
2420+
self._run_test_case(func, [_OUTPUT], {_INPUT: input_x, _INPUT1: input_y})
2421+
self._run_test_case(func, [_OUTPUT], {_INPUT: input_x.astype(np.int32), _INPUT1: input_y})
2422+
self._run_test_case(func, [_OUTPUT], {_INPUT: input_x > 0.5, _INPUT1: input_y})
24232423

24242424
@check_opset_min_version(9, "is_nan")
24252425
def test_isnan(self):

tf2onnx/onnx_opset/generator.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -151,13 +151,13 @@ def version_7(cls, ctx, node, **kwargs):
151151
class ZerosLike:
152152
@classmethod
153153
def version_1(cls, ctx, node, **kwargs):
154-
# T output = ZerosLike(T x)
155-
# when params "dtype" used, tf will call another op "Fill" instead, so Cast is not needed here.
156-
input_dtype = ctx.get_dtype(node.input[0])
157-
node_name = utils.make_name("zero")
158-
const_zero = ctx.make_const(node_name, np.array(0).astype(utils.map_onnx_to_numpy_type(input_dtype)))
159154
shapes = node.output_shapes
160155
dtypes = node.output_dtypes
161156
ctx.remove_node(node.name)
162-
ctx.make_node(op_type="Mul", inputs=[node.input[0], const_zero.output[0]],
163-
name=node.name, outputs=node.output, shapes=shapes, dtypes=dtypes)
157+
casted_input = ctx.make_node("Cast", node.input, attr={'to': onnx_pb.TensorProto.INT64})
158+
const_zero = ctx.make_const(utils.make_name("zero"), np.array(0).astype(np.int64))
159+
mul_node = ctx.make_node('Mul', inputs=[casted_input.output[0], const_zero.output[0]])
160+
ctx.make_node("Cast", inputs=[mul_node.output[0]],
161+
attr={'to': dtypes[0]},
162+
name=node.name, outputs=node.output,
163+
shapes=shapes, dtypes=dtypes)

0 commit comments

Comments
 (0)