Skip to content

Commit 0cb6aff

Browse files
committed
check for various errors when returning results
1 parent 19a5d72 commit 0cb6aff

File tree

1 file changed

+51
-0
lines changed

1 file changed

+51
-0
lines changed

pint_xarray/tests/test_expects.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,3 +184,54 @@ def func(a):
184184

185185
actual = func(1)
186186
assert actual is None
187+
188+
@pytest.mark.parametrize(
189+
[
190+
"return_value_units",
191+
"multiple_units",
192+
"errors",
193+
"multiple_errors",
194+
"message",
195+
],
196+
(
197+
(
198+
("m", "s"),
199+
False,
200+
ValueError,
201+
False,
202+
"mismatched number of return values",
203+
),
204+
("m", True, ValueError, False, "mismatched number of return values"),
205+
(("m",), True, ValueError, False, "mismatched number of return values"),
206+
(1, False, TypeError, True, "units must be of type"),
207+
),
208+
)
209+
def test_return_value_errors(
210+
self, return_value_units, multiple_units, errors, multiple_errors, message
211+
):
212+
if multiple_errors:
213+
root_error = ExceptionGroup
214+
root_message = "Errors while converting return values"
215+
else:
216+
root_error = errors
217+
root_message = message
218+
219+
with pytest.raises(root_error, match=root_message) as excinfo:
220+
221+
@pint_xarray.expects(a=None, b=None, return_value=return_value_units)
222+
def func(a, b):
223+
if multiple_units:
224+
return a, b
225+
else:
226+
return a / b
227+
228+
func(1, 2)
229+
230+
if not multiple_errors:
231+
return
232+
233+
group = excinfo.value
234+
assert len(group.exceptions) == 1, f"Found {len(group.exceptions)} exceptions"
235+
exc = group.exceptions[0]
236+
if not re.search(message, str(exc)):
237+
raise AssertionError(f"exception {exc!r} did not match pattern {message!r}")

0 commit comments

Comments
 (0)