Skip to content

Commit 3ec678b

Browse files
authored
Merge pull request #232 from LinusJungemann/fix/nodeAttrWrongType
Add type checks to set_nodeattr
2 parents 6ea382f + af33396 commit 3ec678b

File tree

2 files changed

+234
-9
lines changed

2 files changed

+234
-9
lines changed

src/qonnx/custom_op/base.py

Lines changed: 53 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
2727
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2828

29+
import numpy as np
2930
import onnx.helper as helper
3031
import onnx.numpy_helper as np_helper
3132
from abc import ABC, abstractmethod
@@ -73,7 +74,9 @@ def get_nodeattr_def(self, name):
7374
elif len(attrdef) == 4:
7475
(dtype, req, def_val, allowed_values) = attrdef
7576
else:
76-
raise Exception("Unexpected length %d n-tuple from get_nodeattr_types" % len(attrdef))
77+
raise Exception(
78+
"Unexpected length %d n-tuple from get_nodeattr_types" % len(attrdef)
79+
)
7780
return (dtype, req, def_val, allowed_values)
7881

7982
def get_nodeattr_allowed_values(self, name):
@@ -129,15 +132,57 @@ def set_nodeattr(self, name, value):
129132
try:
130133
(dtype, req, def_val, allowed_values) = self.get_nodeattr_def(name)
131134
if allowed_values is not None:
132-
assert value in allowed_values, "%s = %s not in %s" % (
133-
str(name),
134-
str(value),
135-
str(allowed_values),
136-
)
135+
if value not in allowed_values:
136+
raise ValueError(
137+
"%s = %s not in %s"
138+
% (str(name), str(value), str(allowed_values))
139+
)
137140
attr = get_by_name(self.onnx_node.attribute, name)
138-
if dtype == "t":
139-
# convert numpy array to TensorProto
141+
142+
# Verify value type matches dtype before setting/converting
143+
if dtype == "i":
144+
if not isinstance(value, int):
145+
raise TypeError(f"Attribute {name} expects int, got {type(value)}")
146+
elif dtype == "f":
147+
if not isinstance(value, float):
148+
raise TypeError(
149+
f"Attribute {name} expects float, got {type(value)}"
150+
)
151+
elif dtype == "s":
152+
if not isinstance(value, (str, bytes)):
153+
raise TypeError(f"Attribute {name} expects str, got {type(value)}")
154+
elif dtype == "ints":
155+
if not (
156+
isinstance(value, list) and all(isinstance(v, int) for v in value)
157+
):
158+
raise TypeError(
159+
f"Attribute {name} expects list of ints, got {type(value)}"
160+
)
161+
elif dtype == "floats":
162+
if not (
163+
isinstance(value, list)
164+
and all(isinstance(v, (int, float)) for v in value)
165+
):
166+
raise TypeError(
167+
f"Attribute {name} expects list of floats, got {type(value)}"
168+
)
169+
elif dtype == "strings":
170+
if not (
171+
isinstance(value, list)
172+
and all(isinstance(v, (str, bytes)) for v in value)
173+
):
174+
raise TypeError(
175+
f"Attribute {name} expects list of strings, got {type(value)}"
176+
)
177+
elif dtype == "t":
178+
# Validate that value is a numpy array
179+
if not isinstance(value, (np.ndarray, np.generic)):
180+
raise TypeError(
181+
f"Attribute {name} expects numpy array, got {type(value)}"
182+
)
183+
# Convert numpy array to TensorProto
140184
value = np_helper.from_array(value)
185+
141186
if attr is not None:
142187
# dtype indicates which ONNX Attribute member to use
143188
# (such as i, f, s...)

tests/custom_op/test_attr.py

Lines changed: 181 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,20 @@
2828

2929
import numpy as np
3030
import onnx.parser as oprs
31+
import pytest
3132

3233
from qonnx.core.modelwrapper import ModelWrapper
3334
from qonnx.custom_op.base import CustomOp
3435
from qonnx.custom_op.registry import add_op_to_domain, getCustomOp
36+
from onnx import helper
3537

3638

3739
class AttrTestOp(CustomOp):
3840
def get_nodeattr_types(self):
39-
my_attrs = {"tensor_attr": ("t", True, np.asarray([])), "strings_attr": ("strings", True, [""])}
41+
my_attrs = {
42+
"tensor_attr": ("t", True, np.asarray([])),
43+
"strings_attr": ("strings", True, [""]),
44+
}
4045
return my_attrs
4146

4247
def make_shape_compatible_op(self, model):
@@ -101,3 +106,178 @@ def test_attr():
101106
strings_attr_prod[0] = "test"
102107
inst.set_nodeattr("strings_attr", strings_attr_prod)
103108
assert inst.get_nodeattr("strings_attr") == ["test"] + strings_attr[1:]
109+
110+
111+
class MyCustomOp(CustomOp):
112+
def __init__(self, onnx_node, onnx_opset_version=10):
113+
super().__init__(onnx_node, onnx_opset_version)
114+
115+
def get_nodeattr_types(self):
116+
return {
117+
"my_int_attr": ("i", False, -1),
118+
"my_float_attr": ("f", False, 0.0),
119+
"my_string_attr": ("s", False, "default"),
120+
"my_ints_attr": ("ints", False, []),
121+
"my_floats_attr": ("floats", False, []),
122+
"my_strings_attr": ("strings", False, []),
123+
"my_allowed_attr": ("i", False, 1, {1, 2, 3}),
124+
"my_tensor_attr": ("t", False, np.array([])),
125+
}
126+
127+
def execute_node(self, context, graph):
128+
pass
129+
130+
def infer_node_datatype(self, model):
131+
pass
132+
133+
def make_shape_compatible_op(self, model):
134+
pass
135+
136+
def verify_node(self):
137+
pass
138+
139+
140+
def test_set_get_nodeattr():
141+
node = helper.make_node("myOpType", [], [])
142+
myCustomOp = MyCustomOp(node, 13)
143+
144+
# Test integer attribute
145+
assert myCustomOp.get_nodeattr("my_int_attr") == -1
146+
myCustomOp.set_nodeattr("my_int_attr", 2)
147+
assert myCustomOp.get_nodeattr("my_int_attr") == 2
148+
149+
# Test that setting wrong type raises TypeError
150+
with pytest.raises(TypeError, match="expects int"):
151+
myCustomOp.set_nodeattr("my_int_attr", 2.5)
152+
with pytest.raises(TypeError, match="expects int"):
153+
myCustomOp.set_nodeattr("my_int_attr", "string")
154+
155+
# Test float attribute
156+
assert myCustomOp.get_nodeattr("my_float_attr") == 0.0
157+
myCustomOp.set_nodeattr("my_float_attr", 3.14)
158+
assert abs(myCustomOp.get_nodeattr("my_float_attr") - 3.14) < 1e-6
159+
160+
with pytest.raises(TypeError, match="expects float"):
161+
myCustomOp.set_nodeattr("my_float_attr", 42)
162+
with pytest.raises(TypeError, match="expects float"):
163+
myCustomOp.set_nodeattr("my_float_attr", "string")
164+
165+
# Test string attribute
166+
assert myCustomOp.get_nodeattr("my_string_attr") == "default"
167+
myCustomOp.set_nodeattr("my_string_attr", "test_value")
168+
assert myCustomOp.get_nodeattr("my_string_attr") == "test_value"
169+
170+
with pytest.raises(TypeError, match="expects str"):
171+
myCustomOp.set_nodeattr("my_string_attr", 123)
172+
with pytest.raises(TypeError, match="expects str"):
173+
myCustomOp.set_nodeattr("my_string_attr", 3.14)
174+
175+
# Test ints attribute
176+
assert myCustomOp.get_nodeattr("my_ints_attr") == []
177+
myCustomOp.set_nodeattr("my_ints_attr", [1, 2, 3])
178+
assert myCustomOp.get_nodeattr("my_ints_attr") == [1, 2, 3]
179+
180+
with pytest.raises(TypeError, match="expects list of ints"):
181+
myCustomOp.set_nodeattr("my_ints_attr", [1, 2.5, 3])
182+
with pytest.raises(TypeError, match="expects list of ints"):
183+
myCustomOp.set_nodeattr("my_ints_attr", [1, "two", 3])
184+
with pytest.raises(TypeError, match="expects list of ints"):
185+
myCustomOp.set_nodeattr("my_ints_attr", 123)
186+
187+
# Test floats attribute
188+
assert myCustomOp.get_nodeattr("my_floats_attr") == []
189+
myCustomOp.set_nodeattr("my_floats_attr", [1.0, 2.5, 3.14])
190+
result = myCustomOp.get_nodeattr("my_floats_attr")
191+
assert len(result) == 3
192+
assert abs(result[0] - 1.0) < 1e-6
193+
assert abs(result[1] - 2.5) < 1e-6
194+
assert abs(result[2] - 3.14) < 1e-6
195+
# floats can accept ints
196+
myCustomOp.set_nodeattr("my_floats_attr", [1, 2, 3])
197+
assert myCustomOp.get_nodeattr("my_floats_attr") == [1, 2, 3]
198+
199+
with pytest.raises(TypeError, match="expects list of floats"):
200+
myCustomOp.set_nodeattr("my_floats_attr", [1.0, "two", 3.0])
201+
with pytest.raises(TypeError, match="expects list of floats"):
202+
myCustomOp.set_nodeattr("my_floats_attr", 3.14)
203+
204+
# Test strings attribute
205+
assert myCustomOp.get_nodeattr("my_strings_attr") == []
206+
myCustomOp.set_nodeattr("my_strings_attr", ["a", "b", "c"])
207+
assert myCustomOp.get_nodeattr("my_strings_attr") == ["a", "b", "c"]
208+
209+
with pytest.raises(TypeError, match="expects list of strings"):
210+
myCustomOp.set_nodeattr("my_strings_attr", ["a", 2, "c"])
211+
with pytest.raises(TypeError, match="expects list of strings"):
212+
myCustomOp.set_nodeattr("my_strings_attr", "not a list")
213+
214+
# Test allowed_values validation
215+
assert myCustomOp.get_nodeattr("my_allowed_attr") == 1
216+
myCustomOp.set_nodeattr("my_allowed_attr", 2)
217+
assert myCustomOp.get_nodeattr("my_allowed_attr") == 2
218+
myCustomOp.set_nodeattr("my_allowed_attr", 3)
219+
assert myCustomOp.get_nodeattr("my_allowed_attr") == 3
220+
221+
with pytest.raises(ValueError, match="not in"):
222+
myCustomOp.set_nodeattr("my_allowed_attr", 5)
223+
224+
# Test tensor attribute (numpy arrays)
225+
default_tensor = myCustomOp.get_nodeattr("my_tensor_attr")
226+
assert default_tensor.shape == (0,)
227+
228+
# Set a 1D numpy array
229+
tensor_1d = np.array([1, 2, 3, 4, 5], dtype=np.int32)
230+
myCustomOp.set_nodeattr("my_tensor_attr", tensor_1d)
231+
result_1d = myCustomOp.get_nodeattr("my_tensor_attr")
232+
assert np.array_equal(result_1d, tensor_1d)
233+
assert result_1d.dtype == tensor_1d.dtype
234+
235+
# Set a 2D numpy array
236+
tensor_2d = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32)
237+
myCustomOp.set_nodeattr("my_tensor_attr", tensor_2d)
238+
result_2d = myCustomOp.get_nodeattr("my_tensor_attr")
239+
assert np.array_equal(result_2d, tensor_2d)
240+
assert result_2d.shape == tensor_2d.shape
241+
242+
# Set a 3D numpy array with different dtype
243+
tensor_3d = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=np.int8)
244+
myCustomOp.set_nodeattr("my_tensor_attr", tensor_3d)
245+
result_3d = myCustomOp.get_nodeattr("my_tensor_attr")
246+
assert np.array_equal(result_3d, tensor_3d)
247+
assert result_3d.shape == (2, 2, 2)
248+
assert result_3d.dtype == np.int8
249+
250+
# Test assigning numpy arrays to non-tensor attributes (should fail)
251+
numpy_arr = np.array([1, 2, 3])
252+
with pytest.raises(TypeError, match="expects int"):
253+
myCustomOp.set_nodeattr("my_int_attr", numpy_arr)
254+
with pytest.raises(TypeError, match="expects float"):
255+
myCustomOp.set_nodeattr("my_float_attr", numpy_arr)
256+
with pytest.raises(TypeError, match="expects str"):
257+
myCustomOp.set_nodeattr("my_string_attr", numpy_arr)
258+
with pytest.raises(TypeError, match="expects list of ints"):
259+
myCustomOp.set_nodeattr("my_ints_attr", numpy_arr)
260+
with pytest.raises(TypeError, match="expects list of floats"):
261+
myCustomOp.set_nodeattr("my_floats_attr", numpy_arr)
262+
with pytest.raises(TypeError, match="expects list of strings"):
263+
myCustomOp.set_nodeattr("my_strings_attr", numpy_arr)
264+
265+
# Test assigning non-numpy values to tensor attribute (should fail or convert)
266+
# Scalars should fail
267+
with pytest.raises((TypeError, AttributeError)):
268+
myCustomOp.set_nodeattr("my_tensor_attr", 42)
269+
with pytest.raises((TypeError, AttributeError)):
270+
myCustomOp.set_nodeattr("my_tensor_attr", 3.14)
271+
with pytest.raises((TypeError, AttributeError)):
272+
myCustomOp.set_nodeattr("my_tensor_attr", "string")
273+
274+
# Test assigning lists to tensor attribute (should fail with TypeError)
275+
# Plain lists are not accepted - must be numpy arrays
276+
with pytest.raises(TypeError, match="expects numpy array"):
277+
myCustomOp.set_nodeattr("my_tensor_attr", [10, 20, 30, 40])
278+
279+
with pytest.raises(TypeError, match="expects numpy array"):
280+
myCustomOp.set_nodeattr("my_tensor_attr", [[1, 2, 3], [4, 5, 6]])
281+
282+
with pytest.raises(TypeError, match="expects numpy array"):
283+
myCustomOp.set_nodeattr("my_tensor_attr", [1.5, 2.5, 3.5])

0 commit comments

Comments
 (0)