Skip to content

Commit 672ed43

Browse files
committed
Fix error msg
1 parent 7bada1d commit 672ed43

File tree

1 file changed

+12
-1
lines changed

1 file changed

+12
-1
lines changed

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,14 @@ def _switch_transpose_and_node(self, node, trans):
247247
shape = self._g.get_shape(node.output[0])
248248
if shape:
249249
# only nhwc transpose can reach here
250+
# if slicing results in non 4-D shape, we cannot infer the outputshape
251+
# JRP
252+
# new_shape = shape
253+
# if len(shape) == len(NHWC_TO_NCHW):
254+
# new_shape = [shape[i] for i in NHWC_TO_NCHW]
255+
# else:
256+
# new_shape = [-1]*len(shape)
257+
# self.logger.warning("%s's shape is unknown, which may interfere further optimization", node.output[0])
250258
new_shape = [shape[i] for i in NHWC_TO_NCHW]
251259
self._g.set_shape(node.output[0], new_shape)
252260
return True
@@ -518,7 +526,7 @@ def _calculate_new_attr(ori_perm, ori_squeeze_axes):
518526
squeeze_shape = self._g.get_shape(node.output[0])
519527
self._g.set_shape(trans.output[0], squeeze_shape)
520528
input_shape = self._g.get_shape(node.input[0])
521-
if input_shape is not None:
529+
if input_shape is not None: # JRP and len(input_shape) == 4:
522530
new_squeeze_output_shape = [input_shape[i] for i in range(4) if i not in new_squeeze_axes]
523531
else:
524532
new_squeeze_output_shape = [-1] * 4
@@ -568,6 +576,9 @@ def _slice_handler(self, trans, node):
568576
else: # in opset 10, axes is input instead of an attribute.
569577
if len(node.inputs) >= 4 and node.inputs[3].is_const():
570578
axes = node.inputs[3].get_tensor_value(as_list=True)
579+
# JRP
580+
#shape = self._g.get_shape(node.output[0])
581+
#if axes == [0, 1, 2, 3] and len(shape) == 4:
571582
if axes == [0, 1, 2, 3]:
572583
# axes node might be shared
573584
new_axes = np.array(NCHW_TO_NHWC, dtype=np.int64)

0 commit comments

Comments
 (0)