Skip to content

Commit 876f2b7

Browse files
committed
prevent over transpoing
1 parent 53c56da commit 876f2b7

File tree

2 files changed

+10
-7
lines changed

2 files changed

+10
-7
lines changed

tf2onnx/graph.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,8 @@ def domain(self, val):
135135
@property
136136
def data_format(self):
137137
"""Return data_format."""
138-
return self.get_attr_str("data_format")
138+
attr_str = self.get_attr_value("data_format")
139+
return "unkown" if attr_str == None else attr_str.decode("utf-8")
139140

140141
@data_format.setter
141142
def data_format(self, val):

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -556,12 +556,14 @@ def _pad_handler(self, trans, node):
556556
new_pads = [pads[0], pads[3], pads[1], pads[2], pads[4], pads[7], pads[5], pads[6]]
557557
node.set_attr("pads", new_pads)
558558
return self._switch_transpose_and_node(node, trans)
559-
if node.inputs[1].is_const() and self._nodes_has_single_consumer_node([node.inputs[1]]):
560-
pads = node.inputs[1].get_tensor_value()
561-
# NHWC->NCHW
562-
new_pads = np.array([pads[0], pads[3], pads[1], pads[2], pads[4], pads[7], pads[5], pads[6]],
563-
dtype=np.int64)
564-
node.inputs[1].set_tensor_value(new_pads)
559+
if node.inputs[1].is_const():
560+
if node.inputs[1].data_format in ["NHWC", "unkown"]:
561+
pads = node.inputs[1].get_tensor_value()
562+
# NHWC->NCHW
563+
new_pads = np.array([pads[0], pads[3], pads[1], pads[2], pads[4], pads[7], pads[5], pads[6]],
564+
dtype=np.int64)
565+
node.inputs[1].set_tensor_value(new_pads)
566+
node.inputs[1].data_format = "NCHW"
565567
return self._switch_transpose_and_node(node, trans)
566568
return False
567569

0 commit comments

Comments
 (0)