15
15
from onnx import onnx_pb , numpy_helper
16
16
from tf2onnx import utils
17
17
from tf2onnx .handler import tf_op
18
+ from onnx .onnx_pb import TensorProto
18
19
19
20
logger = logging .getLogger (__name__ )
20
21
@@ -151,13 +152,13 @@ def version_7(cls, ctx, node, **kwargs):
151
152
class ZerosLike :
152
153
@classmethod
153
154
def version_1 (cls , ctx , node , ** kwargs ):
154
- # T output = ZerosLike(T x)
155
- # when params "dtype" used, tf will call another op "Fill" instead, so Cast is not needed here.
156
- input_dtype = ctx .get_dtype (node .input [0 ])
157
- node_name = utils .make_name ("zero" )
158
- const_zero = ctx .make_const (node_name , np .array (0 ).astype (utils .map_onnx_to_numpy_type (input_dtype )))
159
155
shapes = node .output_shapes
160
156
dtypes = node .output_dtypes
161
157
ctx .remove_node (node .name )
162
- ctx .make_node (op_type = "Mul" , inputs = [node .input [0 ], const_zero .output [0 ]],
163
- name = node .name , outputs = node .output , shapes = shapes , dtypes = dtypes )
158
+ casted_input = ctx .make_node ("Cast" , node .input , attr = {'to' : TensorProto .INT64 })
159
+ const_zero = ctx .make_const (utils .make_name ("zero" ), np .array (0 ).astype (np .int64 ))
160
+ mul_node = ctx .make_node ('Mul' , inputs = [casted_input .output [0 ], const_zero .output [0 ]])
161
+ casted_output = ctx .make_node ("Cast" , inputs = [mul_node .output [0 ]],
162
+ attr = {'to' : dtypes [0 ]},
163
+ name = node .name , outputs = node .output ,
164
+ shapes = shapes , dtypes = dtypes )
0 commit comments