@@ -46,10 +46,6 @@ def __init__(self, node, graph, skip_conversion=False):
46
46
# dict to original attributes
47
47
for a in node .attribute :
48
48
self ._attr [a .name ] = a
49
-
50
- self .data_format = self .get_attr ("data_format" )
51
- if self .data_format :
52
- self .data_format = self .data_format .s .decode ("utf-8" )
53
49
self ._skip_conversion = skip_conversion
54
50
55
51
@property
@@ -123,6 +119,16 @@ def domain(self, val):
123
119
"""Set Op type."""
124
120
self ._op .domain = val
125
121
122
+ @property
123
+ def data_format (self ):
124
+ """Return data_format."""
125
+ return self .get_attr_str ("data_format" )
126
+
127
+ @data_format .setter
128
+ def data_format (self , val ):
129
+ """Set data_format."""
130
+ self .set_attr ("data_format" , val )
131
+
126
132
def is_nhwc (self ):
127
133
"""Return True if node is in NCHW format."""
128
134
return self .data_format == "NHWC"
@@ -141,17 +147,22 @@ def __repr__(self):
141
147
return "<onnx op type='%s' name=%s>" % (self .type , self ._op .name )
142
148
143
149
def get_attr (self , name , default = None ):
144
- """Get attribute map ."""
150
+ """Get raw attribute value ."""
145
151
attr = self .attr .get (name , default )
146
152
return attr
147
153
148
154
def get_attr_int (self , name ):
149
- """Get attribute map ."""
150
- attr = self .attr . get (name )
155
+ """Get attribute value as int ."""
156
+ attr = self .get_attr (name )
151
157
utils .make_sure (attr is not None , "attribute %s is None" , name )
152
158
attr = attr .i
153
159
return attr
154
160
161
+ def get_attr_str (self , name , encoding = "utf-8" ):
162
+ """Get attribute value as string."""
163
+ attr = self .get_attr (name )
164
+ return attr .s .decode (encoding ) if attr else None
165
+
155
166
def set_attr (self , name , value ):
156
167
self .attr [name ] = helper .make_attribute (name , value )
157
168
0 commit comments