@@ -40,16 +40,14 @@ def test_model(
4040 decimal : Optional [int ] = None ,
4141) -> ValidationSummary :
4242 """Test model inference"""
43- precision_args = _handle_legacy_precision_args (
44- absolute_tolerance = absolute_tolerance ,
45- relative_tolerance = relative_tolerance ,
46- decimal = decimal ,
47- )
43+ # NOTE: `decimal` is a legacy argument and is handled in `_test_model_inference`
4844 return test_description (
4945 source ,
5046 weight_format = weight_format ,
5147 devices = devices ,
52- ** precision_args ,
48+ absolute_tolerance = absolute_tolerance ,
49+ relative_tolerance = relative_tolerance ,
50+ decimal = decimal ,
5351 expected_type = "model" ,
5452 )
5553
@@ -66,17 +64,15 @@ def test_description(
6664 expected_type : Optional [str ] = None ,
6765) -> ValidationSummary :
6866 """Test a bioimage.io resource dynamically, e.g. prediction of test tensors for models"""
69- precision_args = _handle_legacy_precision_args (
70- absolute_tolerance = absolute_tolerance ,
71- relative_tolerance = relative_tolerance ,
72- decimal = decimal ,
73- )
67+ # NOTE: `decimal` is a legacy argument and is handled in `_test_model_inference`
7468 rd = load_description_and_test (
7569 source ,
7670 format_version = format_version ,
7771 weight_format = weight_format ,
7872 devices = devices ,
79- ** precision_args ,
73+ absolute_tolerance = absolute_tolerance ,
74+ relative_tolerance = relative_tolerance ,
75+ decimal = decimal ,
8076 expected_type = expected_type ,
8177 )
8278 return rd .validation_summary
@@ -94,12 +90,7 @@ def load_description_and_test(
9490 expected_type : Optional [str ] = None ,
9591) -> Union [ResourceDescr , InvalidDescr ]:
9692 """Test RDF dynamically, e.g. model inference of test inputs"""
97- precision_args = _handle_legacy_precision_args (
98- absolute_tolerance = absolute_tolerance ,
99- relative_tolerance = relative_tolerance ,
100- decimal = decimal ,
101- )
102-
93+ # NOTE: `decimal` is a legacy argument and is handled in `_test_model_inference`
10394 if (
10495 isinstance (source , ResourceDescrBase )
10596 and format_version != "discover"
@@ -132,16 +123,9 @@ def load_description_and_test(
132123 else :
133124 weight_formats = [weight_format ]
134125 for w in weight_formats :
135- # Note: new_precision_args is created like this to avoid type check errors
136- new_precision_args : Dict [str , float ] = {}
137- new_precision_args ["absolute_tolerance" ] = precision_args .get (
138- "absolute_tolerance"
139- )
140- new_precision_args ["relative_tolerance" ] = precision_args .get (
141- "relative_tolerance"
126+ _test_model_inference (
127+ rd , w , devices , absolute_tolerance , relative_tolerance , decimal
142128 )
143-
144- _test_model_inference (rd , w , devices , ** new_precision_args )
145129 if not isinstance (rd , v0_4 .ModelDescr ):
146130 _test_model_inference_parametrized (rd , w , devices )
147131
@@ -157,11 +141,19 @@ def _test_model_inference(
157141 devices : Optional [List [str ]],
158142 absolute_tolerance : float ,
159143 relative_tolerance : float ,
144+ decimal : Optional [int ],
160145) -> None :
161146 test_name = "Reproduce test outputs from test inputs"
162147 logger .info ("starting '{}'" , test_name )
163148 error : Optional [str ] = None
164149 tb : List [str ] = []
150+
151+ precision_args = _handle_legacy_precision_args (
152+ absolute_tolerance = absolute_tolerance ,
153+ relative_tolerance = relative_tolerance ,
154+ decimal = decimal ,
155+ )
156+
165157 try :
166158 inputs = get_test_inputs (model )
167159 expected = get_test_outputs (model )
@@ -184,8 +176,8 @@ def _test_model_inference(
184176 np .testing .assert_allclose (
185177 res .data ,
186178 exp .data ,
187- rtol = relative_tolerance ,
188- atol = absolute_tolerance ,
179+ rtol = precision_args [ " relative_tolerance" ] ,
180+ atol = precision_args [ " absolute_tolerance" ] ,
189181 )
190182 except AssertionError as e :
191183 error = f"Output and expected output disagree:\n { e } "
@@ -396,19 +388,9 @@ def _test_expected_resource_type(
396388 )
397389
398390
399- class PrecisionArgs (TypedDict ):
400- """
401- Arguments, both deprecated and current, for setting the precision during validation.
402- """
403-
404- absolute_tolerance : float
405- relative_tolerance : float
406- decimal : Optional [int ]
407-
408-
409391def _handle_legacy_precision_args (
410392 absolute_tolerance : float , relative_tolerance : float , decimal : Optional [int ]
411- ) -> PrecisionArgs :
393+ ) -> Dict [ str , float ] :
412394 """
413395 Transform the precision arguments to conform with the current implementation.
414396
@@ -420,7 +402,6 @@ def _handle_legacy_precision_args(
420402 return {
421403 "absolute_tolerance" : absolute_tolerance ,
422404 "relative_tolerance" : relative_tolerance ,
423- "decimal" : decimal ,
424405 }
425406
426407 warnings .warn (
@@ -436,7 +417,6 @@ def _handle_legacy_precision_args(
436417 return {
437418 "absolute_tolerance" : 1.5 * 10 ** (- decimal ),
438419 "relative_tolerance" : 0 ,
439- "decimal" : None ,
440420 }
441421
442422
0 commit comments