Skip to content

Commit eae4272

Browse files
committed
fix bug
dtype of slice's input should be same
1 parent acb92b1 commit eae4272

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

tf2onnx/onnx_opset/tensor.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from __future__ import print_function
1010
from __future__ import unicode_literals
1111

12+
import sys
1213
import logging
1314

1415
import numpy as np
@@ -356,7 +357,7 @@ def make_gathernd(ctx, params, indices, output, scope_name, t_params, shapes, dt
356357
# reshape indices into [sum(indices[:-1]), indices[-1]]
357358
indices_shape = ctx.make_node("Shape", [indices], dtypes=[TensorProto.INT64])
358359
indices_size = ctx.make_node("Size", [indices])
359-
attr = {"axes": [0], "ends": [utils.get_max_value(np.int64)], "starts": [-1]}
360+
attr = {"axes": [0], "ends": [sys.maxsize], "starts": [-1]}
360361
inputs_map = {"data": indices_shape.output[0], **attr}
361362
inner_shape = GraphBuilder(ctx).make_slice(inputs_map, dtypes=[TensorProto.INT64])
362363
outter_shape = ctx.make_node("Div",
@@ -414,7 +415,7 @@ def make_gathernd(ctx, params, indices, output, scope_name, t_params, shapes, dt
414415
[inner_loop_shape.output[0], one_const.output[0]],
415416
attr={"axis": 0},
416417
dtypes=[TensorProto.INT64])
417-
attr = {"axes": [0], "ends": [utils.get_max_value(np.int64)], "starts": [1]}
418+
attr = {"axes": [0], "ends": [sys.maxsize], "starts": [1]}
418419
inputs_map = {"data": inner_loop_shape_.output[0], **attr}
419420
output_inner_shape = GraphBuilder(ctx).make_slice(inputs_map, dtypes=[TensorProto.INT64])
420421
attr = {"axes": [0], "ends": [-1], "starts": [0]}

0 commit comments

Comments
 (0)