@@ -218,6 +218,36 @@ def test_data_format(self):
218
218
self .assertEqual (n .data_format , "NHWC" )
219
219
self .assertTrue (n .is_nhwc ())
220
220
221
+ def test_node_attr_onnx (self ):
222
+ n1 = helper .make_node ("Conv" , ["X" , "W" ], ["Y" ], name = "n1" , my_attr = "my_attr" )
223
+ graph_proto = helper .make_graph (
224
+ nodes = [n1 ],
225
+ name = "test" ,
226
+ inputs = [helper .make_tensor_value_info ("X" , TensorProto .FLOAT , [2 , 2 ]),
227
+ helper .make_tensor_value_info ("W" , TensorProto .FLOAT , [2 , 2 ])],
228
+ outputs = [helper .make_tensor_value_info ("Y" , TensorProto .FLOAT , [2 , 2 ])],
229
+ initializer = []
230
+ )
231
+ g = GraphUtil .create_graph_from_onnx_graph (graph_proto )
232
+ n1 = g .get_node_by_name ("n1" )
233
+ self .assertTrue ("my_attr" in n1 .attr )
234
+ self .assertTrue ("my_attr" not in n1 .attr_onnx )
235
+
236
+ n1 = helper .make_node ("Conv" , ["X" , "W" ], ["Y" ], name = "n1" , domain = "my_domain" , my_attr = "my_attr" )
237
+ print (n1 )
238
+ graph_proto = helper .make_graph (
239
+ nodes = [n1 ],
240
+ name = "test" ,
241
+ inputs = [helper .make_tensor_value_info ("X" , TensorProto .FLOAT , [2 , 2 ]),
242
+ helper .make_tensor_value_info ("W" , TensorProto .FLOAT , [2 , 2 ])],
243
+ outputs = [helper .make_tensor_value_info ("Y" , TensorProto .FLOAT , [2 , 2 ])],
244
+ initializer = []
245
+ )
246
+ g = GraphUtil .create_graph_from_onnx_graph (graph_proto )
247
+ n1 = g .get_node_by_name ("n1" )
248
+ self .assertTrue ("my_attr" in n1 .attr )
249
+ self .assertTrue ("my_attr" in n1 .attr_onnx )
250
+
221
251
222
252
if __name__ == '__main__' :
223
253
unittest_main ()
0 commit comments