Skip to content

Commit 04220ba

Browse files
committed
Add support for Keras resnet50 pretrained model testing.
1 parent f5e4ed3 commit 04220ba

File tree

1 file changed

+15
-2
lines changed

1 file changed

+15
-2
lines changed

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -482,8 +482,21 @@ def _mul_handler(self, trans, node):
482482
self._g.remove_node(node.name)
483483
return True
484484

485-
# if the shape is () or (1), we just move transpose after the mul
486-
if not multiplier.shape or (len(multiplier.shape) == 1 and multiplier.shape[0] == 1):
485+
# if the shape is (), we just move transpose after the mul
486+
if not multiplier.shape:
487+
return self._switch_transpose_and_node(node, trans)
488+
489+
# if multiplier is 1-D
490+
if len(multiplier.shape) == 1:
491+
if multiplier.shape[0] == 1:
492+
# shape is (1)
493+
return self._switch_transpose_and_node(node, trans)
494+
495+
# shape is (N). reshape so that trans(shape) = 1,1,...,N
496+
perm = list(trans.get_attr('perm').ints)
497+
new_shape = np.ones(len(perm), dtype=np.int32)
498+
new_shape[perm[-1]] = multiplier.shape[0]
499+
multiplier_input_node.set_tensor_value(multiplier.reshape(new_shape))
487500
return self._switch_transpose_and_node(node, trans)
488501

489502
return False

0 commit comments

Comments
 (0)