Skip to content

Commit d59146c

Browse files
committed
Update tf.fill() to use dynamic size instead of const.
Update MatrixBandPart to set
1 parent e8ca5a8 commit d59146c

File tree

1 file changed

+2
-12
lines changed

1 file changed

+2
-12
lines changed

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,7 @@ def _initialize_handlers(self):
174174
"Cast": self._simple_through_handler,
175175
"Clip": self._simple_through_handler,
176176
"Concat": self._concat_handler,
177+
"Elu": self._simple_through_handler,
177178
"Identity": self._identity_handler,
178179
"LeakyRelu": self._simple_through_handler,
179180
"Max": self._maxmin_handler,
@@ -247,14 +248,6 @@ def _switch_transpose_and_node(self, node, trans):
247248
shape = self._g.get_shape(node.output[0])
248249
if shape:
249250
# only nhwc transpose can reach here
250-
# if slicing results in non 4-D shape, we cannot infer the outputshape
251-
# JRP
252-
# new_shape = shape
253-
# if len(shape) == len(NHWC_TO_NCHW):
254-
# new_shape = [shape[i] for i in NHWC_TO_NCHW]
255-
# else:
256-
# new_shape = [-1]*len(shape)
257-
# self.logger.warning("%s's shape is unknown, which may interfere further optimization", node.output[0])
258251
new_shape = [shape[i] for i in NHWC_TO_NCHW]
259252
self._g.set_shape(node.output[0], new_shape)
260253
return True
@@ -526,7 +519,7 @@ def _calculate_new_attr(ori_perm, ori_squeeze_axes):
526519
squeeze_shape = self._g.get_shape(node.output[0])
527520
self._g.set_shape(trans.output[0], squeeze_shape)
528521
input_shape = self._g.get_shape(node.input[0])
529-
if input_shape is not None: # JRP and len(input_shape) == 4:
522+
if input_shape is not None:
530523
new_squeeze_output_shape = [input_shape[i] for i in range(4) if i not in new_squeeze_axes]
531524
else:
532525
new_squeeze_output_shape = [-1] * 4
@@ -576,9 +569,6 @@ def _slice_handler(self, trans, node):
576569
else: # in opset 10, axes is input instead of an attribute.
577570
if len(node.inputs) >= 4 and node.inputs[3].is_const():
578571
axes = node.inputs[3].get_tensor_value(as_list=True)
579-
# JRP
580-
#shape = self._g.get_shape(node.output[0])
581-
#if axes == [0, 1, 2, 3] and len(shape) == 4:
582572
if axes == [0, 1, 2, 3]:
583573
# axes node might be shared
584574
new_axes = np.array(NCHW_TO_NHWC, dtype=np.int64)

0 commit comments

Comments
 (0)