Skip to content

Commit da181ea

Browse files
committed
refac: handle legacy args only in _test_model_inference
1 parent 9e3a3fe commit da181ea

File tree

1 file changed

+22
-42
lines changed

1 file changed

+22
-42
lines changed

bioimageio/core/_resource_tests.py

Lines changed: 22 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -40,16 +40,14 @@ def test_model(
4040
decimal: Optional[int] = None,
4141
) -> ValidationSummary:
4242
"""Test model inference"""
43-
precision_args = _handle_legacy_precision_args(
44-
absolute_tolerance=absolute_tolerance,
45-
relative_tolerance=relative_tolerance,
46-
decimal=decimal,
47-
)
43+
# NOTE: `decimal` is a legacy argument and is handled in `_test_model_inference`
4844
return test_description(
4945
source,
5046
weight_format=weight_format,
5147
devices=devices,
52-
**precision_args,
48+
absolute_tolerance=absolute_tolerance,
49+
relative_tolerance=relative_tolerance,
50+
decimal=decimal,
5351
expected_type="model",
5452
)
5553

@@ -66,17 +64,15 @@ def test_description(
6664
expected_type: Optional[str] = None,
6765
) -> ValidationSummary:
6866
"""Test a bioimage.io resource dynamically, e.g. prediction of test tensors for models"""
69-
precision_args = _handle_legacy_precision_args(
70-
absolute_tolerance=absolute_tolerance,
71-
relative_tolerance=relative_tolerance,
72-
decimal=decimal,
73-
)
67+
# NOTE: `decimal` is a legacy argument and is handled in `_test_model_inference`
7468
rd = load_description_and_test(
7569
source,
7670
format_version=format_version,
7771
weight_format=weight_format,
7872
devices=devices,
79-
**precision_args,
73+
absolute_tolerance=absolute_tolerance,
74+
relative_tolerance=relative_tolerance,
75+
decimal=decimal,
8076
expected_type=expected_type,
8177
)
8278
return rd.validation_summary
@@ -94,12 +90,7 @@ def load_description_and_test(
9490
expected_type: Optional[str] = None,
9591
) -> Union[ResourceDescr, InvalidDescr]:
9692
"""Test RDF dynamically, e.g. model inference of test inputs"""
97-
precision_args = _handle_legacy_precision_args(
98-
absolute_tolerance=absolute_tolerance,
99-
relative_tolerance=relative_tolerance,
100-
decimal=decimal,
101-
)
102-
93+
# NOTE: `decimal` is a legacy argument and is handled in `_test_model_inference`
10394
if (
10495
isinstance(source, ResourceDescrBase)
10596
and format_version != "discover"
@@ -132,16 +123,9 @@ def load_description_and_test(
132123
else:
133124
weight_formats = [weight_format]
134125
for w in weight_formats:
135-
# Note: new_precision_args is created like this to avoid type check errors
136-
new_precision_args: Dict[str, float] = {}
137-
new_precision_args["absolute_tolerance"] = precision_args.get(
138-
"absolute_tolerance"
139-
)
140-
new_precision_args["relative_tolerance"] = precision_args.get(
141-
"relative_tolerance"
126+
_test_model_inference(
127+
rd, w, devices, absolute_tolerance, relative_tolerance, decimal
142128
)
143-
144-
_test_model_inference(rd, w, devices, **new_precision_args)
145129
if not isinstance(rd, v0_4.ModelDescr):
146130
_test_model_inference_parametrized(rd, w, devices)
147131

@@ -157,11 +141,19 @@ def _test_model_inference(
157141
devices: Optional[List[str]],
158142
absolute_tolerance: float,
159143
relative_tolerance: float,
144+
decimal: Optional[int],
160145
) -> None:
161146
test_name = "Reproduce test outputs from test inputs"
162147
logger.info("starting '{}'", test_name)
163148
error: Optional[str] = None
164149
tb: List[str] = []
150+
151+
precision_args = _handle_legacy_precision_args(
152+
absolute_tolerance=absolute_tolerance,
153+
relative_tolerance=relative_tolerance,
154+
decimal=decimal,
155+
)
156+
165157
try:
166158
inputs = get_test_inputs(model)
167159
expected = get_test_outputs(model)
@@ -184,8 +176,8 @@ def _test_model_inference(
184176
np.testing.assert_allclose(
185177
res.data,
186178
exp.data,
187-
rtol=relative_tolerance,
188-
atol=absolute_tolerance,
179+
rtol=precision_args["relative_tolerance"],
180+
atol=precision_args["absolute_tolerance"],
189181
)
190182
except AssertionError as e:
191183
error = f"Output and expected output disagree:\n {e}"
@@ -396,19 +388,9 @@ def _test_expected_resource_type(
396388
)
397389

398390

399-
class PrecisionArgs(TypedDict):
400-
"""
401-
Arguments, both deprecated and current, for setting the precision during validation.
402-
"""
403-
404-
absolute_tolerance: float
405-
relative_tolerance: float
406-
decimal: Optional[int]
407-
408-
409391
def _handle_legacy_precision_args(
410392
absolute_tolerance: float, relative_tolerance: float, decimal: Optional[int]
411-
) -> PrecisionArgs:
393+
) -> Dict[str, float]:
412394
"""
413395
Transform the precision arguments to conform with the current implementation.
414396
@@ -420,7 +402,6 @@ def _handle_legacy_precision_args(
420402
return {
421403
"absolute_tolerance": absolute_tolerance,
422404
"relative_tolerance": relative_tolerance,
423-
"decimal": decimal,
424405
}
425406

426407
warnings.warn(
@@ -436,7 +417,6 @@ def _handle_legacy_precision_args(
436417
return {
437418
"absolute_tolerance": 1.5 * 10 ** (-decimal),
438419
"relative_tolerance": 0,
439-
"decimal": None,
440420
}
441421

442422

0 commit comments

Comments
 (0)