Skip to content

Commit ba5c41f

Browse files
author
mdaniowi
committed
strings attr test added to test_attr.py
1 parent 5414416 commit ba5c41f

File tree

1 file changed

+15
-1
lines changed

1 file changed

+15
-1
lines changed

tests/custom_op/test_attr.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,11 @@
3737

3838
class AttrTestOp(CustomOp):
3939
def get_nodeattr_types(self):
40-
return {"tensor_attr": ("t", True, np.asarray([]))}
40+
my_attrs = {
41+
"tensor_attr": ("t", True, np.asarray([])),
42+
"strings_attr": ("strings", True, [""])
43+
}
44+
return my_attrs
4145

4246
def make_shape_compatible_op(self, model):
4347
param_tensor = self.get_nodeattr("tensor_attr")
@@ -70,6 +74,7 @@ def test_attr():
7074
strarr = np.array2string(w, separator=", ")
7175
w_str = strarr.replace("[", "{").replace("]", "}").replace(" ", "")
7276
tensor_attr_str = f"int8{wshp_str} {w_str}"
77+
strings_attr = ["a", "bc", "def"]
7378

7479
input = f"""
7580
<
@@ -86,9 +91,18 @@ def test_attr():
8691
model = oprs.parse_model(input)
8792
model = ModelWrapper(model)
8893
inst = getCustomOp(model.graph.node[0])
94+
8995
w_prod = inst.get_nodeattr("tensor_attr")
9096
assert (w_prod == w).all()
9197
w = w - 1
9298
inst.set_nodeattr("tensor_attr", w)
9399
w_prod = inst.get_nodeattr("tensor_attr")
94100
assert (w_prod == w).all()
101+
102+
inst.set_nodeattr("strings_attr", strings_attr)
103+
strings_attr_prod = inst.get_nodeattr("strings_attr")
104+
assert strings_attr_prod == strings_attr
105+
strings_attr_prod[0] = "test"
106+
inst.set_nodeattr("strings_attr", strings_attr_prod)
107+
assert inst.get_nodeattr("strings_attr") == ["test"] + strings_attr[1:]
108+

0 commit comments

Comments
 (0)