Skip to content

Commit 324f15c

Browse files
committed
use determinism and refactor decimal
1 parent a1b6f5c commit 324f15c

File tree

1 file changed

+28
-50
lines changed

1 file changed

+28
-50
lines changed

bioimageio/core/_resource_tests.py

Lines changed: 28 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -100,16 +100,18 @@ def test_model(
100100
absolute_tolerance: float = 1.5e-4,
101101
relative_tolerance: float = 1e-4,
102102
decimal: Optional[int] = None,
103+
*,
104+
determinism: Literal["seed_only", "full"] = "seed_only",
103105
) -> ValidationSummary:
104106
"""Test model inference"""
105-
# NOTE: `decimal` is a legacy argument and is handled in `_test_model_inference`
106107
return test_description(
107108
source,
108109
weight_format=weight_format,
109110
devices=devices,
110111
absolute_tolerance=absolute_tolerance,
111112
relative_tolerance=relative_tolerance,
112113
decimal=decimal,
114+
determinism=determinism,
113115
expected_type="model",
114116
)
115117

@@ -123,10 +125,10 @@ def test_description(
123125
absolute_tolerance: float = 1.5e-4,
124126
relative_tolerance: float = 1e-4,
125127
decimal: Optional[int] = None,
128+
determinism: Literal["seed_only", "full"] = "seed_only",
126129
expected_type: Optional[str] = None,
127130
) -> ValidationSummary:
128131
"""Test a bioimage.io resource dynamically, e.g. prediction of test tensors for models"""
129-
# NOTE: `decimal` is a legacy argument and is handled in `_test_model_inference`
130132
rd = load_description_and_test(
131133
source,
132134
format_version=format_version,
@@ -135,6 +137,7 @@ def test_description(
135137
absolute_tolerance=absolute_tolerance,
136138
relative_tolerance=relative_tolerance,
137139
decimal=decimal,
140+
determinism=determinism,
138141
expected_type=expected_type,
139142
)
140143
return rd.validation_summary
@@ -149,10 +152,10 @@ def load_description_and_test(
149152
absolute_tolerance: float = 1.5e-4,
150153
relative_tolerance: float = 1e-4,
151154
decimal: Optional[int] = None,
155+
determinism: Literal["seed_only", "full"] = "seed_only",
152156
expected_type: Optional[str] = None,
153157
) -> Union[ResourceDescr, InvalidDescr]:
154158
"""Test RDF dynamically, e.g. model inference of test inputs"""
155-
# NOTE: `decimal` is a legacy argument and is handled in `_test_model_inference`
156159
if (
157160
isinstance(source, ResourceDescrBase)
158161
and format_version != "discover"
@@ -184,10 +187,25 @@ def load_description_and_test(
184187
] # pyright: ignore[reportAssignmentType]
185188
else:
186189
weight_formats = [weight_format]
187-
for w in weight_formats:
188-
_test_model_inference(
189-
rd, w, devices, absolute_tolerance, relative_tolerance, decimal
190+
191+
if decimal is None:
192+
atol = absolute_tolerance
193+
rtol = relative_tolerance
194+
else:
195+
warnings.warn(
196+
"The argument `decimal` has been deprecated in favour of"
197+
+ " `relative_tolerance` and `absolute_tolerance`, with different"
198+
+ " validation logic, using `numpy.testing.assert_allclose, see"
199+
+ " 'https://numpy.org/doc/stable/reference/generated/"
200+
+ " numpy.testing.assert_allclose.html'. Passing a value for `decimal`"
201+
+ " will cause validation to revert to the old behaviour."
190202
)
203+
atol = 1.5 * 10 ** (-decimal)
204+
rtol = 0
205+
206+
enable_determinism(determinism)
207+
for w in weight_formats:
208+
_test_model_inference(rd, w, devices, atol, rtol)
191209
if not isinstance(rd, v0_4.ModelDescr):
192210
_test_model_inference_parametrized(rd, w, devices)
193211

@@ -201,21 +219,14 @@ def _test_model_inference(
201219
model: Union[v0_4.ModelDescr, v0_5.ModelDescr],
202220
weight_format: WeightsFormat,
203221
devices: Optional[Sequence[str]],
204-
absolute_tolerance: float,
205-
relative_tolerance: float,
206-
decimal: Optional[int],
222+
atol: float,
223+
rtol: float,
207224
) -> None:
208225
test_name = f"Reproduce test outputs from test inputs ({weight_format})"
209226
logger.info("starting '{}'", test_name)
210227
error: Optional[str] = None
211228
tb: List[str] = []
212229

213-
precision_args = _handle_legacy_precision_args(
214-
absolute_tolerance=absolute_tolerance,
215-
relative_tolerance=relative_tolerance,
216-
decimal=decimal,
217-
)
218-
219230
try:
220231
inputs = get_test_inputs(model)
221232
expected = get_test_outputs(model)
@@ -238,8 +249,8 @@ def _test_model_inference(
238249
np.testing.assert_allclose(
239250
res.data,
240251
exp.data,
241-
rtol=precision_args["relative_tolerance"],
242-
atol=precision_args["absolute_tolerance"],
252+
rtol=rtol,
253+
atol=atol,
243254
)
244255
except AssertionError as e:
245256
error = f"Output and expected output disagree:\n {e}"
@@ -455,39 +466,6 @@ def _test_expected_resource_type(
455466
)
456467

457468

458-
def _handle_legacy_precision_args(
459-
absolute_tolerance: float, relative_tolerance: float, decimal: Optional[int]
460-
) -> Dict[str, float]:
461-
"""
462-
Transform the precision arguments to conform with the current implementation.
463-
464-
If the deprecated `decimal` argument is used it overrides the new behaviour with
465-
the old behaviour.
466-
"""
467-
# Already conforms with current implementation
468-
if decimal is None:
469-
return {
470-
"absolute_tolerance": absolute_tolerance,
471-
"relative_tolerance": relative_tolerance,
472-
}
473-
else:
474-
warnings.warn(
475-
"The argument `decimal` has been depricated in favour of "
476-
+ "`relative_tolerance` and `absolute_tolerance`, with different validation "
477-
+ "logic, using `numpy.testing.assert_allclose, see "
478-
+ "'https://numpy.org/doc/stable/reference/generated/"
479-
+ "numpy.testing.assert_allclose.html'. Passing a value for `decimal` will "
480-
+ "cause validation to revert to the old behaviour."
481-
)
482-
483-
# decimal overrides new behaviour,
484-
# have to convert the params to emulate old behaviour
485-
return {
486-
"absolute_tolerance": 1.5 * 10 ** (-decimal),
487-
"relative_tolerance": 0,
488-
}
489-
490-
491469
# TODO: Implement `debug_model()`
492470
# def debug_model(
493471
# model_rdf: Union[RawResourceDescr, ResourceDescr, URI, Path, str],

0 commit comments

Comments
 (0)