@@ -151,13 +151,13 @@ def version_7(cls, ctx, node, **kwargs):
151
151
class ZerosLike :
152
152
@classmethod
153
153
def version_1 (cls , ctx , node , ** kwargs ):
154
- # T output = ZerosLike(T x)
155
- # when params "dtype" used, tf will call another op "Fill" instead, so Cast is not needed here.
156
- input_dtype = ctx .get_dtype (node .input [0 ])
157
- node_name = utils .make_name ("zero" )
158
- const_zero = ctx .make_const (node_name , np .array (0 ).astype (utils .map_onnx_to_numpy_type (input_dtype )))
159
154
shapes = node .output_shapes
160
155
dtypes = node .output_dtypes
161
156
ctx .remove_node (node .name )
162
- ctx .make_node (op_type = "Mul" , inputs = [node .input [0 ], const_zero .output [0 ]],
163
- name = node .name , outputs = node .output , shapes = shapes , dtypes = dtypes )
157
+ casted_input = ctx .make_node ("Cast" , node .input , attr = {'to' : onnx_pb .TensorProto .INT64 })
158
+ const_zero = ctx .make_const (utils .make_name ("zero" ), np .array (0 ).astype (np .int64 ))
159
+ mul_node = ctx .make_node ('Mul' , inputs = [casted_input .output [0 ], const_zero .output [0 ]])
160
+ ctx .make_node ("Cast" , inputs = [mul_node .output [0 ]],
161
+ attr = {'to' : dtypes [0 ]},
162
+ name = node .name , outputs = node .output ,
163
+ shapes = shapes , dtypes = dtypes )
0 commit comments