@@ -108,6 +108,22 @@ def test_insert_node2(self):
108
108
'n5_raw_output___3:0 -> n6 n5_raw_output___3:0 -> n5_graph_outputs_Identity__4 }'
109
109
self .assertEqual (expected , result )
110
110
111
+ def test_make_const_string (self ):
112
+ graph_proto = self .sample_net ()
113
+ g = GraphUtil .create_graph_from_onnx_graph (graph_proto )
114
+ arr1 = np .array ("test" , np .object )
115
+ arr2 = np .array ([["A" , "B" ], ["C" , "D" ]], np .object )
116
+ arr3 = np .array (b"test" , np .object )
117
+ arr4 = np .array ([[b"A" , b"B" ], [b"C" , b"D" ]], np .object )
118
+ const1 = g .make_const ("const1" , arr1 )
119
+ const2 = g .make_const ("const2" , arr2 )
120
+ const3 = g .make_const ("const3" , arr3 )
121
+ const4 = g .make_const ("const4" , arr4 )
122
+ np .testing .assert_equal (const1 .get_tensor_value (False ), arr1 )
123
+ np .testing .assert_equal (const2 .get_tensor_value (False ), arr2 )
124
+ np .testing .assert_equal (const3 .get_tensor_value (False ), arr1 )
125
+ np .testing .assert_equal (const4 .get_tensor_value (False ), arr2 )
126
+
111
127
def test_remove_input (self ):
112
128
graph_proto = self .sample_net ()
113
129
g = GraphUtil .create_graph_from_onnx_graph (graph_proto )
0 commit comments