Skip to content

Commit cecc171

Browse files
committed
add test_node_attr_onnx
1 parent bf33b6a commit cecc171

File tree

1 file changed

+30
-0
lines changed

1 file changed

+30
-0
lines changed

tests/test_internals.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,36 @@ def test_data_format(self):
218218
self.assertEqual(n.data_format, "NHWC")
219219
self.assertTrue(n.is_nhwc())
220220

221+
def test_node_attr_onnx(self):
222+
n1 = helper.make_node("Conv", ["X", "W"], ["Y"], name="n1", my_attr="my_attr")
223+
graph_proto = helper.make_graph(
224+
nodes=[n1],
225+
name="test",
226+
inputs=[helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 2]),
227+
helper.make_tensor_value_info("W", TensorProto.FLOAT, [2, 2])],
228+
outputs=[helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 2])],
229+
initializer=[]
230+
)
231+
g = GraphUtil.create_graph_from_onnx_graph(graph_proto)
232+
n1 = g.get_node_by_name("n1")
233+
self.assertTrue("my_attr" in n1.attr)
234+
self.assertTrue("my_attr" not in n1.attr_onnx)
235+
236+
n1 = helper.make_node("Conv", ["X", "W"], ["Y"], name="n1", domain="my_domain", my_attr="my_attr")
237+
print(n1)
238+
graph_proto = helper.make_graph(
239+
nodes=[n1],
240+
name="test",
241+
inputs=[helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 2]),
242+
helper.make_tensor_value_info("W", TensorProto.FLOAT, [2, 2])],
243+
outputs=[helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 2])],
244+
initializer=[]
245+
)
246+
g = GraphUtil.create_graph_from_onnx_graph(graph_proto)
247+
n1 = g.get_node_by_name("n1")
248+
self.assertTrue("my_attr" in n1.attr)
249+
self.assertTrue("my_attr" in n1.attr_onnx)
250+
221251

222252
if __name__ == '__main__':
223253
unittest_main()

0 commit comments

Comments
 (0)