|
15 | 15 | from onnx import onnx_pb, numpy_helper
|
16 | 16 | from tf2onnx import utils
|
17 | 17 | from tf2onnx.handler import tf_op
|
18 |
| -from onnx.onnx_pb import TensorProto |
19 | 18 |
|
20 | 19 | logger = logging.getLogger(__name__)
|
21 | 20 |
|
@@ -155,10 +154,10 @@ def version_1(cls, ctx, node, **kwargs):
|
155 | 154 | shapes = node.output_shapes
|
156 | 155 | dtypes = node.output_dtypes
|
157 | 156 | ctx.remove_node(node.name)
|
158 |
| - casted_input = ctx.make_node("Cast", node.input, attr={'to': TensorProto.INT64}) |
| 157 | + casted_input = ctx.make_node("Cast", node.input, attr={'to': onnx_pb.TensorProto.INT64}) |
159 | 158 | const_zero = ctx.make_const(utils.make_name("zero"), np.array(0).astype(np.int64))
|
160 | 159 | mul_node = ctx.make_node('Mul', inputs=[casted_input.output[0], const_zero.output[0]])
|
161 |
| - casted_output = ctx.make_node("Cast", inputs=[mul_node.output[0]], |
162 |
| - attr={'to': dtypes[0]}, |
163 |
| - name=node.name, outputs=node.output, |
164 |
| - shapes=shapes, dtypes=dtypes) |
| 160 | + ctx.make_node("Cast", inputs=[mul_node.output[0]], |
| 161 | + attr={'to': dtypes[0]}, |
| 162 | + name=node.name, outputs=node.output, |
| 163 | + shapes=shapes, dtypes=dtypes) |
0 commit comments