Skip to content

Commit 0f880fc

Browse files
authored
Merge pull request #521 from zhijxu-MS/bert_bug
fix bug found in "bert"
2 parents bc7b7ae + db14785 commit 0f880fc

File tree

2 files changed

+10
-10
lines changed

2 files changed

+10
-10
lines changed

tf2onnx/onnx_opset/tensor.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from __future__ import print_function
1010
from __future__ import unicode_literals
1111

12+
import sys
1213
import logging
1314

1415
import numpy as np
@@ -174,12 +175,10 @@ def version_4(cls, ctx, node, **kwargs):
174175
if perm.is_const():
175176
# perms is passed as const
176177
dims = perm.get_tensor_value()
178+
ctx.remove_input(node, node.input[1])
179+
node.set_attr("perm", dims)
177180
else:
178-
# calculate perms from shape
179-
shape = ctx.get_shape(node.input[1])
180-
dims = [i for i in range(len(shape) - 1, -1)]
181-
ctx.remove_input(node, node.input[1])
182-
node.set_attr("perm", dims)
181+
utils.make_sure(False, "perm can't be dynamic in ONNX")
183182
else:
184183
# graph rewrite moved perm to attribute
185184
pass
@@ -356,7 +355,7 @@ def make_gathernd(ctx, params, indices, output, scope_name, t_params, shapes, dt
356355
# reshape indices into [sum(indices[:-1]), indices[-1]]
357356
indices_shape = ctx.make_node("Shape", [indices], dtypes=[TensorProto.INT64])
358357
indices_size = ctx.make_node("Size", [indices])
359-
attr = {"axes": [0], "ends": [utils.get_max_value(np.int64)], "starts": [-1]}
358+
attr = {"axes": [0], "ends": [sys.maxsize], "starts": [-1]}
360359
inputs_map = {"data": indices_shape.output[0], **attr}
361360
inner_shape = GraphBuilder(ctx).make_slice(inputs_map, dtypes=[TensorProto.INT64])
362361
outter_shape = ctx.make_node("Div",
@@ -414,7 +413,7 @@ def make_gathernd(ctx, params, indices, output, scope_name, t_params, shapes, dt
414413
[inner_loop_shape.output[0], one_const.output[0]],
415414
attr={"axis": 0},
416415
dtypes=[TensorProto.INT64])
417-
attr = {"axes": [0], "ends": [utils.get_max_value(np.int64)], "starts": [1]}
416+
attr = {"axes": [0], "ends": [sys.maxsize], "starts": [1]}
418417
inputs_map = {"data": inner_loop_shape_.output[0], **attr}
419418
output_inner_shape = GraphBuilder(ctx).make_slice(inputs_map, dtypes=[TensorProto.INT64])
420419
attr = {"axes": [0], "ends": [-1], "starts": [0]}

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,16 +60,17 @@ def pre_optimize_action(self):
6060
target_t = reshape_op.inputs[0].get_tensor_value(as_list=False)
6161
target_shape = reshape_op.inputs[1].get_tensor_value(as_list=False)
6262
new_data = np.reshape(target_t, tuple(target_shape))
63-
const_name = utils.port_name(utils.make_name("Const"))
63+
const_name = reshape_op.output[0]
64+
self._g.remove_node(reshape_op.name)
65+
self._g.make_const(const_name, new_data)
6466

6567
# point all children nodes inputs to the new node
6668
for output_name in reshape_op.output:
6769
for child in ops:
6870
for i, name in enumerate(child.input):
6971
if name == output_name:
7072
child.input[i] = const_name
71-
self._g.make_const(const_name, new_data)
72-
self._g.remove_node(reshape_op.name)
73+
7374
self._g.topological_sort(self._g.get_nodes())
7475

7576
def post_optimize_action(self):

0 commit comments

Comments
 (0)