Skip to content

Commit 01c403b

Browse files
committed
Fix Transpose + Pad handler, for Keras app MobilenetV2 model
1 parent 59fed17 commit 01c403b

File tree

3 files changed

+29
-5
lines changed

3 files changed

+29
-5
lines changed

tests/run_pretrained_models.yaml

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -399,6 +399,18 @@ keras_resnet50:
399399
model: ResNet50
400400
model_type: keras
401401
input_get: get_ramp
402+
inputs:
403+
"input_1:0": [1, 224, 224, 3]
404+
outputs:
405+
- Identity:0
406+
407+
keras_mobilenet_v2:
408+
tf_min_version: 2.1
409+
disabled: false
410+
url: module://tensorflow.keras.applications.mobilenet_v2/MobileNetV2
411+
model: MobileNetV2
412+
model_type: keras
413+
input_get: get_ramp
402414
inputs:
403415
"input_1:0": [1, 224, 224, 3]
404416
outputs:

tf2onnx/graph.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -504,6 +504,13 @@ def make_const(self, name, np_val, skip_conversion=False, raw=True):
504504
self.set_dtype(name, utils.map_numpy_to_onnx_dtype(np_val.dtype))
505505
return node
506506

507+
def copy_const(self, node, name=None):
508+
"""Copy a const node, using name is specified"""
509+
# TODO: support attr copy starting opset 12
510+
if name is None:
511+
name = utils.make_name(node.name)
512+
return self.make_const(name, node.get_tensor_value(as_list=False))
513+
507514
def make_node(self, op_type, inputs, attr=None, output_count=1, outputs=None, skip_conversion=True,
508515
op_name_scope=None, name=None, shapes=None, dtypes=None, domain=constants.ONNX_DOMAIN,
509516
infer_shape_dtype=True):

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -602,14 +602,19 @@ def _pad_handler(self, trans, node):
602602
new_pads = [pads[0], pads[3], pads[1], pads[2], pads[4], pads[7], pads[5], pads[6]]
603603
node.set_attr("pads", new_pads)
604604
return self._switch_transpose_and_node(node, trans)
605-
if node.inputs[1].is_const():
606-
if node.inputs[1].data_format in ["NHWC", "unkown"]:
607-
pads = node.inputs[1].get_tensor_value()
605+
606+
input1 = node.inputs[1]
607+
if input1.is_const():
608+
if input1.data_format in ["NHWC", "unkown"]:
609+
if not self._nodes_has_single_consumer_node([input1]):
610+
input1 = self.g.copy_const(input1)
611+
self.input[1] = input1.output[0]
612+
pads = input1.get_tensor_value()
608613
# NHWC->NCHW
609614
new_pads = np.array([pads[0], pads[3], pads[1], pads[2], pads[4], pads[7], pads[5], pads[6]],
610615
dtype=np.int64)
611-
node.inputs[1].set_tensor_value(new_pads)
612-
node.inputs[1].data_format = "NCHW"
616+
input1.set_tensor_value(new_pads)
617+
input1.data_format = "NCHW"
613618
return self._switch_transpose_and_node(node, trans)
614619
return False
615620

0 commit comments

Comments
 (0)