@@ -184,3 +184,54 @@ def func(a):
184
184
185
185
actual = func (1 )
186
186
assert actual is None
187
+
188
+ @pytest .mark .parametrize (
189
+ [
190
+ "return_value_units" ,
191
+ "multiple_units" ,
192
+ "errors" ,
193
+ "multiple_errors" ,
194
+ "message" ,
195
+ ],
196
+ (
197
+ (
198
+ ("m" , "s" ),
199
+ False ,
200
+ ValueError ,
201
+ False ,
202
+ "mismatched number of return values" ,
203
+ ),
204
+ ("m" , True , ValueError , False , "mismatched number of return values" ),
205
+ (("m" ,), True , ValueError , False , "mismatched number of return values" ),
206
+ (1 , False , TypeError , True , "units must be of type" ),
207
+ ),
208
+ )
209
+ def test_return_value_errors (
210
+ self , return_value_units , multiple_units , errors , multiple_errors , message
211
+ ):
212
+ if multiple_errors :
213
+ root_error = ExceptionGroup
214
+ root_message = "Errors while converting return values"
215
+ else :
216
+ root_error = errors
217
+ root_message = message
218
+
219
+ with pytest .raises (root_error , match = root_message ) as excinfo :
220
+
221
+ @pint_xarray .expects (a = None , b = None , return_value = return_value_units )
222
+ def func (a , b ):
223
+ if multiple_units :
224
+ return a , b
225
+ else :
226
+ return a / b
227
+
228
+ func (1 , 2 )
229
+
230
+ if not multiple_errors :
231
+ return
232
+
233
+ group = excinfo .value
234
+ assert len (group .exceptions ) == 1 , f"Found { len (group .exceptions )} exceptions"
235
+ exc = group .exceptions [0 ]
236
+ if not re .search (message , str (exc )):
237
+ raise AssertionError (f"exception { exc !r} did not match pattern { message !r} " )
0 commit comments