Skip to content

Commit 65ae24a

Browse files
committed
node get/set data_format via attr
1 parent ad365f6 commit 65ae24a

File tree

1 file changed

+18
-7
lines changed

1 file changed

+18
-7
lines changed

tf2onnx/graph.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,6 @@ def __init__(self, node, graph, skip_conversion=False):
4646
# dict to original attributes
4747
for a in node.attribute:
4848
self._attr[a.name] = a
49-
50-
self.data_format = self.get_attr("data_format")
51-
if self.data_format:
52-
self.data_format = self.data_format.s.decode("utf-8")
5349
self._skip_conversion = skip_conversion
5450

5551
@property
@@ -123,6 +119,16 @@ def domain(self, val):
123119
"""Set Op type."""
124120
self._op.domain = val
125121

122+
@property
123+
def data_format(self):
124+
"""Return data_format."""
125+
return self.get_attr_str("data_format")
126+
127+
@data_format.setter
128+
def data_format(self, val):
129+
"""Set data_format."""
130+
self.set_attr("data_format", val)
131+
126132
def is_nhwc(self):
127133
"""Return True if node is in NCHW format."""
128134
return self.data_format == "NHWC"
@@ -141,17 +147,22 @@ def __repr__(self):
141147
return "<onnx op type='%s' name=%s>" % (self.type, self._op.name)
142148

143149
def get_attr(self, name, default=None):
144-
"""Get attribute map."""
150+
"""Get raw attribute value."""
145151
attr = self.attr.get(name, default)
146152
return attr
147153

148154
def get_attr_int(self, name):
149-
"""Get attribute map."""
150-
attr = self.attr.get(name)
155+
"""Get attribute value as int."""
156+
attr = self.get_attr(name)
151157
utils.make_sure(attr is not None, "attribute %s is None", name)
152158
attr = attr.i
153159
return attr
154160

161+
def get_attr_str(self, name, encoding="utf-8"):
162+
"""Get attribute value as string."""
163+
attr = self.get_attr(name)
164+
return attr.s.decode(encoding) if attr else None
165+
155166
def set_attr(self, name, value):
156167
self.attr[name] = helper.make_attribute(name, value)
157168

0 commit comments

Comments
 (0)