@@ -35,13 +35,18 @@ def test_model(
3535 source : Union [v0_5 .ModelDescr , PermissiveFileSource ],
3636 weight_format : Optional [WeightsFormat ] = None ,
3737 devices : Optional [List [str ]] = None ,
38- decimal : int = 4 ,
38+ absolute_tolerance : float = 1.5e-4 ,
39+ relative_tolerance : float = 1e-4 ,
40+ decimal : Optional [int ] = None ,
3941) -> ValidationSummary :
4042 """Test model inference"""
43+ # NOTE: `decimal` is a legacy argument and is handled in `_test_model_inference`
4144 return test_description (
4245 source ,
4346 weight_format = weight_format ,
4447 devices = devices ,
48+ absolute_tolerance = absolute_tolerance ,
49+ relative_tolerance = relative_tolerance ,
4550 decimal = decimal ,
4651 expected_type = "model" ,
4752 )
@@ -53,15 +58,20 @@ def test_description(
5358 format_version : Union [Literal ["discover" , "latest" ], str ] = "discover" ,
5459 weight_format : Optional [WeightsFormat ] = None ,
5560 devices : Optional [List [str ]] = None ,
56- decimal : int = 4 ,
61+ absolute_tolerance : float = 1.5e-4 ,
62+ relative_tolerance : float = 1e-4 ,
63+ decimal : Optional [int ] = None ,
5764 expected_type : Optional [str ] = None ,
5865) -> ValidationSummary :
5966 """Test a bioimage.io resource dynamically, e.g. prediction of test tensors for models"""
67+ # NOTE: `decimal` is a legacy argument and is handled in `_test_model_inference`
6068 rd = load_description_and_test (
6169 source ,
6270 format_version = format_version ,
6371 weight_format = weight_format ,
6472 devices = devices ,
73+ absolute_tolerance = absolute_tolerance ,
74+ relative_tolerance = relative_tolerance ,
6575 decimal = decimal ,
6676 expected_type = expected_type ,
6777 )
@@ -74,10 +84,13 @@ def load_description_and_test(
7484 format_version : Union [Literal ["discover" , "latest" ], str ] = "discover" ,
7585 weight_format : Optional [WeightsFormat ] = None ,
7686 devices : Optional [List [str ]] = None ,
77- decimal : int = 4 ,
87+ absolute_tolerance : float = 1.5e-4 ,
88+ relative_tolerance : float = 1e-4 ,
89+ decimal : Optional [int ] = None ,
7890 expected_type : Optional [str ] = None ,
7991) -> Union [ResourceDescr , InvalidDescr ]:
8092 """Test RDF dynamically, e.g. model inference of test inputs"""
93+ # NOTE: `decimal` is a legacy argument and is handled in `_test_model_inference`
8194 if (
8295 isinstance (source , ResourceDescrBase )
8396 and format_version != "discover"
@@ -110,7 +123,9 @@ def load_description_and_test(
110123 else :
111124 weight_formats = [weight_format ]
112125 for w in weight_formats :
113- _test_model_inference (rd , w , devices , decimal )
126+ _test_model_inference (
127+ rd , w , devices , absolute_tolerance , relative_tolerance , decimal
128+ )
114129 if not isinstance (rd , v0_4 .ModelDescr ):
115130 _test_model_inference_parametrized (rd , w , devices )
116131
@@ -124,12 +139,21 @@ def _test_model_inference(
124139 model : Union [v0_4 .ModelDescr , v0_5 .ModelDescr ],
125140 weight_format : WeightsFormat ,
126141 devices : Optional [List [str ]],
127- decimal : int ,
142+ absolute_tolerance : float ,
143+ relative_tolerance : float ,
144+ decimal : Optional [int ],
128145) -> None :
129146 test_name = "Reproduce test outputs from test inputs"
130147 logger .info ("starting '{}'" , test_name )
131148 error : Optional [str ] = None
132149 tb : List [str ] = []
150+
151+ precision_args = _handle_legacy_precision_args (
152+ absolute_tolerance = absolute_tolerance ,
153+ relative_tolerance = relative_tolerance ,
154+ decimal = decimal ,
155+ )
156+
133157 try :
134158 inputs = get_test_inputs (model )
135159 expected = get_test_outputs (model )
@@ -149,8 +173,11 @@ def _test_model_inference(
149173 error = "Output tensors for test case may not be None"
150174 break
151175 try :
152- np .testing .assert_array_almost_equal (
153- res .data , exp .data , decimal = decimal
176+ np .testing .assert_allclose (
177+ res .data ,
178+ exp .data ,
179+ rtol = precision_args ["relative_tolerance" ],
180+ atol = precision_args ["absolute_tolerance" ],
154181 )
155182 except AssertionError as e :
156183 error = f"Output and expected output disagree:\n { e } "
@@ -361,6 +388,38 @@ def _test_expected_resource_type(
361388 )
362389
363390
391+ def _handle_legacy_precision_args (
392+ absolute_tolerance : float , relative_tolerance : float , decimal : Optional [int ]
393+ ) -> Dict [str , float ]:
394+ """
395+ Transform the precision arguments to conform with the current implementation.
396+
397+ If the deprecated `decimal` argument is used it overrides the new behaviour with
398+ the old behaviour.
399+ """
400+ # Already conforms with current implementation
401+ if decimal is None :
402+ return {
403+ "absolute_tolerance" : absolute_tolerance ,
404+ "relative_tolerance" : relative_tolerance ,
405+ }
406+
407+ warnings .warn (
408+ "The argument `decimal` has been depricated in favour of "
409+ + "`relative_tolerance` and `absolute_tolerance`, with different validation "
410+ + "logic, using `numpy.testing.assert_allclose, see "
411+ + "'https://numpy.org/doc/stable/reference/generated/"
412+ + "numpy.testing.assert_allclose.html'. Passing a value for `decimal` will "
413+ + "cause validation to revert to the old behaviour."
414+ )
415+ # decimal overrides new behaviour,
416+ # have to convert the params to emulate old behaviour
417+ return {
418+ "absolute_tolerance" : 1.5 * 10 ** (- decimal ),
419+ "relative_tolerance" : 0 ,
420+ }
421+
422+
364423# def debug_model(
365424# model_rdf: Union[RawResourceDescr, ResourceDescr, URI, Path, str],
366425# *,
0 commit comments