3737
3838class AttrTestOp (CustomOp ):
3939 def get_nodeattr_types (self ):
40- return {"tensor_attr" : ("t" , True , np .asarray ([]))}
40+ my_attrs = {
41+ "tensor_attr" : ("t" , True , np .asarray ([])),
42+ "strings_attr" : ("strings" , True , ["" ])
43+ }
44+ return my_attrs
4145
4246 def make_shape_compatible_op (self , model ):
4347 param_tensor = self .get_nodeattr ("tensor_attr" )
@@ -70,6 +74,7 @@ def test_attr():
7074 strarr = np .array2string (w , separator = ", " )
7175 w_str = strarr .replace ("[" , "{" ).replace ("]" , "}" ).replace (" " , "" )
7276 tensor_attr_str = f"int8{ wshp_str } { w_str } "
77+ strings_attr = ["a" , "bc" , "def" ]
7378
7479 input = f"""
7580 <
@@ -86,9 +91,18 @@ def test_attr():
8691 model = oprs .parse_model (input )
8792 model = ModelWrapper (model )
8893 inst = getCustomOp (model .graph .node [0 ])
94+
8995 w_prod = inst .get_nodeattr ("tensor_attr" )
9096 assert (w_prod == w ).all ()
9197 w = w - 1
9298 inst .set_nodeattr ("tensor_attr" , w )
9399 w_prod = inst .get_nodeattr ("tensor_attr" )
94100 assert (w_prod == w ).all ()
101+
102+ inst .set_nodeattr ("strings_attr" , strings_attr )
103+ strings_attr_prod = inst .get_nodeattr ("strings_attr" )
104+ assert strings_attr_prod == strings_attr
105+ strings_attr_prod [0 ] = "test"
106+ inst .set_nodeattr ("strings_attr" , strings_attr_prod )
107+ assert inst .get_nodeattr ("strings_attr" ) == ["test" ] + strings_attr [1 :]
108+
0 commit comments