Skip to content

Commit 10b8fc7

Browse files
authored
Merge pull request #372 from nbcsm/data_format
node get/set data_format via attribute internally
2 parents 09e839c + 5248aef commit 10b8fc7

File tree

2 files changed

+33
-7
lines changed

2 files changed

+33
-7
lines changed

tests/test_internals.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,21 @@ def test_shape_utils(self):
203203
self.assertFalse(utils.are_shapes_equal(None, []))
204204
self.assertTrue(utils.are_shapes_equal([1, 2, 3], (1, 2, 3)))
205205

206+
def test_data_format(self):
207+
n1 = helper.make_node("Conv", ["X", "W"], ["Y"], name="n1", data_format="NHWC")
208+
graph_proto = helper.make_graph(
209+
nodes=[n1],
210+
name="test",
211+
inputs=[helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 2]),
212+
helper.make_tensor_value_info("W", TensorProto.FLOAT, [2, 2])],
213+
outputs=[helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 2])],
214+
initializer=[]
215+
)
216+
g = GraphUtil.create_graph_from_onnx_graph(graph_proto)
217+
n = g.get_node_by_name("n1")
218+
self.assertEqual(n.data_format, "NHWC")
219+
self.assertTrue(n.is_nhwc())
220+
206221

207222
if __name__ == '__main__':
208223
unittest_main()

tf2onnx/graph.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,6 @@ def __init__(self, node, graph, skip_conversion=False):
4646
# dict to original attributes
4747
for a in node.attribute:
4848
self._attr[a.name] = a
49-
50-
self.data_format = self.get_attr("data_format")
51-
if self.data_format:
52-
self.data_format = self.data_format.s.decode("utf-8")
5349
self._skip_conversion = skip_conversion
5450

5551
@property
@@ -123,6 +119,16 @@ def domain(self, val):
123119
"""Set Op type."""
124120
self._op.domain = val
125121

122+
@property
123+
def data_format(self):
124+
"""Return data_format."""
125+
return self.get_attr_str("data_format")
126+
127+
@data_format.setter
128+
def data_format(self, val):
129+
"""Set data_format."""
130+
self.set_attr("data_format", val)
131+
126132
def is_nhwc(self):
127133
"""Return True if node is in NCHW format."""
128134
return self.data_format == "NHWC"
@@ -141,17 +147,22 @@ def __repr__(self):
141147
return "<onnx op type='%s' name=%s>" % (self.type, self._op.name)
142148

143149
def get_attr(self, name, default=None):
144-
"""Get attribute map."""
150+
"""Get raw attribute value."""
145151
attr = self.attr.get(name, default)
146152
return attr
147153

148154
def get_attr_int(self, name):
149-
"""Get attribute map."""
150-
attr = self.attr.get(name)
155+
"""Get attribute value as int."""
156+
attr = self.get_attr(name)
151157
utils.make_sure(attr is not None, "attribute %s is None", name)
152158
attr = attr.i
153159
return attr
154160

161+
def get_attr_str(self, name, encoding="utf-8"):
162+
"""Get attribute value as string."""
163+
attr = self.get_attr(name)
164+
return attr.s.decode(encoding) if attr else None
165+
155166
def set_attr(self, name, value):
156167
self.attr[name] = helper.make_attribute(name, value)
157168

0 commit comments

Comments
 (0)