|
9 | 9 | from __future__ import print_function
|
10 | 10 | from __future__ import unicode_literals
|
11 | 11 |
|
| 12 | +import sys |
12 | 13 | import logging
|
13 | 14 |
|
14 | 15 | import numpy as np
|
@@ -174,12 +175,10 @@ def version_4(cls, ctx, node, **kwargs):
|
174 | 175 | if perm.is_const():
|
175 | 176 | # perms is passed as const
|
176 | 177 | dims = perm.get_tensor_value()
|
| 178 | + ctx.remove_input(node, node.input[1]) |
| 179 | + node.set_attr("perm", dims) |
177 | 180 | else:
|
178 |
| - # calculate perms from shape |
179 |
| - shape = ctx.get_shape(node.input[1]) |
180 |
| - dims = [i for i in range(len(shape) - 1, -1)] |
181 |
| - ctx.remove_input(node, node.input[1]) |
182 |
| - node.set_attr("perm", dims) |
| 181 | + utils.make_sure(False, "perm can't be dynamic in ONNX") |
183 | 182 | else:
|
184 | 183 | # graph rewrite moved perm to attribute
|
185 | 184 | pass
|
@@ -356,7 +355,7 @@ def make_gathernd(ctx, params, indices, output, scope_name, t_params, shapes, dt
|
356 | 355 | # reshape indices into [sum(indices[:-1]), indices[-1]]
|
357 | 356 | indices_shape = ctx.make_node("Shape", [indices], dtypes=[TensorProto.INT64])
|
358 | 357 | indices_size = ctx.make_node("Size", [indices])
|
359 |
| - attr = {"axes": [0], "ends": [utils.get_max_value(np.int64)], "starts": [-1]} |
| 358 | + attr = {"axes": [0], "ends": [sys.maxsize], "starts": [-1]} |
360 | 359 | inputs_map = {"data": indices_shape.output[0], **attr}
|
361 | 360 | inner_shape = GraphBuilder(ctx).make_slice(inputs_map, dtypes=[TensorProto.INT64])
|
362 | 361 | outter_shape = ctx.make_node("Div",
|
@@ -414,7 +413,7 @@ def make_gathernd(ctx, params, indices, output, scope_name, t_params, shapes, dt
|
414 | 413 | [inner_loop_shape.output[0], one_const.output[0]],
|
415 | 414 | attr={"axis": 0},
|
416 | 415 | dtypes=[TensorProto.INT64])
|
417 |
| - attr = {"axes": [0], "ends": [utils.get_max_value(np.int64)], "starts": [1]} |
| 416 | + attr = {"axes": [0], "ends": [sys.maxsize], "starts": [1]} |
418 | 417 | inputs_map = {"data": inner_loop_shape_.output[0], **attr}
|
419 | 418 | output_inner_shape = GraphBuilder(ctx).make_slice(inputs_map, dtypes=[TensorProto.INT64])
|
420 | 419 | attr = {"axes": [0], "ends": [-1], "starts": [0]}
|
|
0 commit comments