3737
3838class AttrTestOp (CustomOp ):
3939 def get_nodeattr_types (self ):
40- return {"tensor_attr" : ("t" , True , np .asarray ([]))}
40+ my_attrs = {"tensor_attr" : ("t" , True , np .asarray ([])), "strings_attr" : ("strings" , True , ["" ])}
41+ return my_attrs
4142
4243 def make_shape_compatible_op (self , model ):
4344 param_tensor = self .get_nodeattr ("tensor_attr" )
@@ -70,6 +71,7 @@ def test_attr():
7071 strarr = np .array2string (w , separator = ", " )
7172 w_str = strarr .replace ("[" , "{" ).replace ("]" , "}" ).replace (" " , "" )
7273 tensor_attr_str = f"int8{ wshp_str } { w_str } "
74+ strings_attr = ["a" , "bc" , "def" ]
7375
7476 input = f"""
7577 <
@@ -86,9 +88,17 @@ def test_attr():
8688 model = oprs .parse_model (input )
8789 model = ModelWrapper (model )
8890 inst = getCustomOp (model .graph .node [0 ])
91+
8992 w_prod = inst .get_nodeattr ("tensor_attr" )
9093 assert (w_prod == w ).all ()
9194 w = w - 1
9295 inst .set_nodeattr ("tensor_attr" , w )
9396 w_prod = inst .get_nodeattr ("tensor_attr" )
9497 assert (w_prod == w ).all ()
98+
99+ inst .set_nodeattr ("strings_attr" , strings_attr )
100+ strings_attr_prod = inst .get_nodeattr ("strings_attr" )
101+ assert strings_attr_prod == strings_attr
102+ strings_attr_prod [0 ] = "test"
103+ inst .set_nodeattr ("strings_attr" , strings_attr_prod )
104+ assert inst .get_nodeattr ("strings_attr" ) == ["test" ] + strings_attr [1 :]
0 commit comments