Skip to content

Commit d3dd7f0

Browse files
authored
Merge pull request #907 from jignparm/jignparm/fix_transpose_pad
Fix Transpose + Pad handler, for Keras app MobilenetV2 model
2 parents 59fed17 + b9ba4e1 commit d3dd7f0

File tree

4 files changed

+30
-5
lines changed

4 files changed

+30
-5
lines changed

tests/run_pretrained_models.yaml

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -399,6 +399,18 @@ keras_resnet50:
399399
model: ResNet50
400400
model_type: keras
401401
input_get: get_ramp
402+
inputs:
403+
"input_1:0": [1, 224, 224, 3]
404+
outputs:
405+
- Identity:0
406+
407+
keras_mobilenet_v2:
408+
tf_min_version: 2.1
409+
disabled: false
410+
url: module://tensorflow.keras.applications.mobilenet_v2/MobileNetV2
411+
model: MobileNetV2
412+
model_type: keras
413+
input_get: get_ramp
402414
inputs:
403415
"input_1:0": [1, 224, 224, 3]
404416
outputs:

tf2onnx/graph.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -504,6 +504,13 @@ def make_const(self, name, np_val, skip_conversion=False, raw=True):
504504
self.set_dtype(name, utils.map_numpy_to_onnx_dtype(np_val.dtype))
505505
return node
506506

507+
def copy_const(self, node, name=None):
508+
"""Copy a const node, using name if specified"""
509+
# TODO: support attr copy starting at opset 12
510+
if name is None:
511+
name = utils.make_name(node.name)
512+
return self.make_const(name, node.get_tensor_value(as_list=False))
513+
507514
def make_node(self, op_type, inputs, attr=None, output_count=1, outputs=None, skip_conversion=True,
508515
op_name_scope=None, name=None, shapes=None, dtypes=None, domain=constants.ONNX_DOMAIN,
509516
infer_shape_dtype=True):

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -602,14 +602,19 @@ def _pad_handler(self, trans, node):
602602
new_pads = [pads[0], pads[3], pads[1], pads[2], pads[4], pads[7], pads[5], pads[6]]
603603
node.set_attr("pads", new_pads)
604604
return self._switch_transpose_and_node(node, trans)
605-
if node.inputs[1].is_const():
606-
if node.inputs[1].data_format in ["NHWC", "unkown"]:
607-
pads = node.inputs[1].get_tensor_value()
605+
606+
input1 = node.inputs[1]
607+
if input1.is_const():
608+
if input1.data_format in ["NHWC", "unkown"]:
609+
if not self._nodes_has_single_consumer_node([input1]):
610+
input1 = self._g.copy_const(input1)
611+
node.input[1] = input1.output[0]
612+
pads = input1.get_tensor_value()
608613
# NHWC->NCHW
609614
new_pads = np.array([pads[0], pads[3], pads[1], pads[2], pads[4], pads[7], pads[5], pads[6]],
610615
dtype=np.int64)
611-
node.inputs[1].set_tensor_value(new_pads)
612-
node.inputs[1].data_format = "NCHW"
616+
input1.set_tensor_value(new_pads)
617+
input1.data_format = "NCHW"
613618
return self._switch_transpose_and_node(node, trans)
614619
return False
615620

tf2onnx/tf_loader.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,7 @@ def from_keras(model_path, input_names, output_names):
260260
# Handles Keras when Eager mode is enabled.
261261
custom_objects = None
262262
if context.executing_eagerly():
263+
_keras.backend.clear_session()
263264
_keras.backend.set_learning_phase(False)
264265
keras_model = _keras.models.load_model(model_path, custom_objects)
265266

0 commit comments

Comments
 (0)