Skip to content

Commit 7118977

Browse files
authored
Merge pull request #854 from RandySheriffH/rashuai/FixPadConv
prevent multi-transpose on pad
2 parents 5f3dd1f + 7df24c2 commit 7118977

File tree

2 files changed

+9
-6
lines changed

2 files changed

+9
-6
lines changed

tf2onnx/graph.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,8 @@ def domain(self, val):
135135
@property
136136
def data_format(self):
137137
"""Return data_format."""
138-
return self.get_attr_str("data_format")
138+
attr_str = self.get_attr_value("data_format")
139+
return "unkown" if attr_str is None else attr_str.decode("utf-8")
139140

140141
@data_format.setter
141142
def data_format(self, val):

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -557,11 +557,13 @@ def _pad_handler(self, trans, node):
557557
node.set_attr("pads", new_pads)
558558
return self._switch_transpose_and_node(node, trans)
559559
if node.inputs[1].is_const():
560-
pads = node.inputs[1].get_tensor_value()
561-
# NHWC->NCHW
562-
new_pads = np.array([pads[0], pads[3], pads[1], pads[2], pads[4], pads[7], pads[5], pads[6]],
563-
dtype=np.int64)
564-
node.inputs[1].set_tensor_value(new_pads)
560+
if node.inputs[1].data_format in ["NHWC", "unkown"]:
561+
pads = node.inputs[1].get_tensor_value()
562+
# NHWC->NCHW
563+
new_pads = np.array([pads[0], pads[3], pads[1], pads[2], pads[4], pads[7], pads[5], pads[6]],
564+
dtype=np.int64)
565+
node.inputs[1].set_tensor_value(new_pads)
566+
node.inputs[1].data_format = "NCHW"
565567
return self._switch_transpose_and_node(node, trans)
566568
return False
567569

0 commit comments

Comments
 (0)