@@ -54,11 +54,6 @@ def test_modelwrapper():
5454 assert first_conv_iname != "" and (first_conv_iname is not None )
5555 assert first_conv_wname != "" and (first_conv_wname is not None )
5656 assert first_conv_oname != "" and (first_conv_oname is not None )
57- first_conv_weights = model .get_initializer (first_conv_wname )
58- assert first_conv_weights .shape == (8 , 1 , 5 , 5 )
59- first_conv_weights_rand = np .random .randn (8 , 1 , 5 , 5 )
60- model .set_initializer (first_conv_wname , first_conv_weights_rand )
61- assert (model .get_initializer (first_conv_wname ) == first_conv_weights_rand ).all ()
6257 inp_cons = model .find_consumer (first_conv_iname )
6358 assert inp_cons == first_conv
6459 out_prod = model .find_producer (first_conv_oname )
@@ -75,6 +70,21 @@ def test_modelwrapper():
7570 assert model .get_tensor_sparsity (first_conv_iname ) == inp_sparsity
7671
7772
73+ def test_modelwrapper_set_get_rm_initializer ():
74+ raw_m = get_data ("qonnx.data" , "onnx/mnist-conv/model.onnx" )
75+ model = ModelWrapper (raw_m )
76+ conv_nodes = model .get_nodes_by_op_type ("Conv" )
77+ first_conv = conv_nodes [0 ]
78+ first_conv_wname = first_conv .input [1 ]
79+ first_conv_weights = model .get_initializer (first_conv_wname )
80+ assert first_conv_weights .shape == (8 , 1 , 5 , 5 )
81+ first_conv_weights_rand = np .random .randn (8 , 1 , 5 , 5 )
82+ model .set_initializer (first_conv_wname , first_conv_weights_rand )
83+ assert (model .get_initializer (first_conv_wname ) == first_conv_weights_rand ).all ()
84+ model .del_initializer (first_conv_wname )
85+ assert model .get_initializer (first_conv_wname ) is None
86+
87+
7888def test_modelwrapper_graph_order ():
7989 # create small network with properties to be tested
8090 Neg_node = onnx .helper .make_node ("Neg" , inputs = ["in1" ], outputs = ["neg1" ])
0 commit comments