@@ -100,16 +100,18 @@ def test_model(
100100 absolute_tolerance : float = 1.5e-4 ,
101101 relative_tolerance : float = 1e-4 ,
102102 decimal : Optional [int ] = None ,
103+ * ,
104+ determinism : Literal ["seed_only" , "full" ] = "seed_only" ,
103105) -> ValidationSummary :
104106 """Test model inference"""
105- # NOTE: `decimal` is a legacy argument and is handled in `_test_model_inference`
106107 return test_description (
107108 source ,
108109 weight_format = weight_format ,
109110 devices = devices ,
110111 absolute_tolerance = absolute_tolerance ,
111112 relative_tolerance = relative_tolerance ,
112113 decimal = decimal ,
114+ determinism = determinism ,
113115 expected_type = "model" ,
114116 )
115117
@@ -123,10 +125,10 @@ def test_description(
123125 absolute_tolerance : float = 1.5e-4 ,
124126 relative_tolerance : float = 1e-4 ,
125127 decimal : Optional [int ] = None ,
128+ determinism : Literal ["seed_only" , "full" ] = "seed_only" ,
126129 expected_type : Optional [str ] = None ,
127130) -> ValidationSummary :
128131 """Test a bioimage.io resource dynamically, e.g. prediction of test tensors for models"""
129- # NOTE: `decimal` is a legacy argument and is handled in `_test_model_inference`
130132 rd = load_description_and_test (
131133 source ,
132134 format_version = format_version ,
@@ -135,6 +137,7 @@ def test_description(
135137 absolute_tolerance = absolute_tolerance ,
136138 relative_tolerance = relative_tolerance ,
137139 decimal = decimal ,
140+ determinism = determinism ,
138141 expected_type = expected_type ,
139142 )
140143 return rd .validation_summary
@@ -149,10 +152,10 @@ def load_description_and_test(
149152 absolute_tolerance : float = 1.5e-4 ,
150153 relative_tolerance : float = 1e-4 ,
151154 decimal : Optional [int ] = None ,
155+ determinism : Literal ["seed_only" , "full" ] = "seed_only" ,
152156 expected_type : Optional [str ] = None ,
153157) -> Union [ResourceDescr , InvalidDescr ]:
154158 """Test RDF dynamically, e.g. model inference of test inputs"""
155- # NOTE: `decimal` is a legacy argument and is handled in `_test_model_inference`
156159 if (
157160 isinstance (source , ResourceDescrBase )
158161 and format_version != "discover"
@@ -184,10 +187,25 @@ def load_description_and_test(
184187 ] # pyright: ignore[reportAssignmentType]
185188 else :
186189 weight_formats = [weight_format ]
187- for w in weight_formats :
188- _test_model_inference (
189- rd , w , devices , absolute_tolerance , relative_tolerance , decimal
190+
191+ if decimal is None :
192+ atol = absolute_tolerance
193+ rtol = relative_tolerance
194+ else :
195+ warnings .warn (
196+ "The argument `decimal` has been deprecated in favour of"
197+ + " `relative_tolerance` and `absolute_tolerance`, with different"
198+ + " validation logic, using `numpy.testing.assert_allclose, see"
199+ + " 'https://numpy.org/doc/stable/reference/generated/"
200+ + " numpy.testing.assert_allclose.html'. Passing a value for `decimal`"
201+ + " will cause validation to revert to the old behaviour."
190202 )
203+ atol = 1.5 * 10 ** (- decimal )
204+ rtol = 0
205+
206+ enable_determinism (determinism )
207+ for w in weight_formats :
208+ _test_model_inference (rd , w , devices , atol , rtol )
191209 if not isinstance (rd , v0_4 .ModelDescr ):
192210 _test_model_inference_parametrized (rd , w , devices )
193211
@@ -201,21 +219,14 @@ def _test_model_inference(
201219 model : Union [v0_4 .ModelDescr , v0_5 .ModelDescr ],
202220 weight_format : WeightsFormat ,
203221 devices : Optional [Sequence [str ]],
204- absolute_tolerance : float ,
205- relative_tolerance : float ,
206- decimal : Optional [int ],
222+ atol : float ,
223+ rtol : float ,
207224) -> None :
208225 test_name = f"Reproduce test outputs from test inputs ({ weight_format } )"
209226 logger .info ("starting '{}'" , test_name )
210227 error : Optional [str ] = None
211228 tb : List [str ] = []
212229
213- precision_args = _handle_legacy_precision_args (
214- absolute_tolerance = absolute_tolerance ,
215- relative_tolerance = relative_tolerance ,
216- decimal = decimal ,
217- )
218-
219230 try :
220231 inputs = get_test_inputs (model )
221232 expected = get_test_outputs (model )
@@ -238,8 +249,8 @@ def _test_model_inference(
238249 np .testing .assert_allclose (
239250 res .data ,
240251 exp .data ,
241- rtol = precision_args [ "relative_tolerance" ] ,
242- atol = precision_args [ "absolute_tolerance" ] ,
252+ rtol = rtol ,
253+ atol = atol ,
243254 )
244255 except AssertionError as e :
245256 error = f"Output and expected output disagree:\n { e } "
@@ -455,39 +466,6 @@ def _test_expected_resource_type(
455466 )
456467
457468
458- def _handle_legacy_precision_args (
459- absolute_tolerance : float , relative_tolerance : float , decimal : Optional [int ]
460- ) -> Dict [str , float ]:
461- """
462- Transform the precision arguments to conform with the current implementation.
463-
464- If the deprecated `decimal` argument is used it overrides the new behaviour with
465- the old behaviour.
466- """
467- # Already conforms with current implementation
468- if decimal is None :
469- return {
470- "absolute_tolerance" : absolute_tolerance ,
471- "relative_tolerance" : relative_tolerance ,
472- }
473- else :
474- warnings .warn (
475- "The argument `decimal` has been depricated in favour of "
476- + "`relative_tolerance` and `absolute_tolerance`, with different validation "
477- + "logic, using `numpy.testing.assert_allclose, see "
478- + "'https://numpy.org/doc/stable/reference/generated/"
479- + "numpy.testing.assert_allclose.html'. Passing a value for `decimal` will "
480- + "cause validation to revert to the old behaviour."
481- )
482-
483- # decimal overrides new behaviour,
484- # have to convert the params to emulate old behaviour
485- return {
486- "absolute_tolerance" : 1.5 * 10 ** (- decimal ),
487- "relative_tolerance" : 0 ,
488- }
489-
490-
491469# TODO: Implement `debug_model()`
492470# def debug_model(
493471# model_rdf: Union[RawResourceDescr, ResourceDescr, URI, Path, str],
0 commit comments