Skip to content

Commit cd543e0

Browse files
committed
always test all weight formats
1 parent 81ea7db commit cd543e0

File tree

1 file changed

+18
-23
lines changed

1 file changed

+18
-23
lines changed

bioimageio/core/_resource_tests.py

Lines changed: 18 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -103,9 +103,16 @@ def load_description_and_test(
103103
_test_expected_resource_type(rd, expected_type)
104104

105105
if isinstance(rd, (v0_4.ModelDescr, v0_5.ModelDescr)):
106-
_test_model_inference(rd, weight_format, devices, decimal)
107-
if not isinstance(rd, v0_4.ModelDescr):
108-
_test_model_inference_parametrized(rd, weight_format, devices)
106+
if weight_format is None:
107+
weight_formats: List[WeightsFormat] = [
108+
w for w, we in rd.weights if we is not None
109+
] # pyright: ignore[reportAssignmentType]
110+
else:
111+
weight_formats = [weight_format]
112+
for w in weight_formats:
113+
_test_model_inference(rd, w, devices, decimal)
114+
if not isinstance(rd, v0_4.ModelDescr):
115+
_test_model_inference_parametrized(rd, w, devices)
109116

110117
# TODO: add execution of jupyter notebooks
111118
# TODO: add more tests
@@ -115,7 +122,7 @@ def load_description_and_test(
115122

116123
def _test_model_inference(
117124
model: Union[v0_4.ModelDescr, v0_5.ModelDescr],
118-
weight_format: Optional[WeightsFormat],
125+
weight_format: WeightsFormat,
119126
devices: Optional[List[str]],
120127
decimal: int,
121128
) -> None:
@@ -161,11 +168,7 @@ def _test_model_inference(
161168
if error is None
162169
else [
163170
ErrorEntry(
164-
loc=(
165-
("weights",)
166-
if weight_format is None
167-
else ("weights", weight_format)
168-
),
171+
loc=("weights", weight_format),
169172
msg=error,
170173
type="bioimageio.core",
171174
traceback=tb,
@@ -178,7 +181,7 @@ def _test_model_inference(
178181

179182
def _test_model_inference_parametrized(
180183
model: v0_5.ModelDescr,
181-
weight_format: Optional[WeightsFormat],
184+
weight_format: WeightsFormat,
182185
devices: Optional[List[str]],
183186
) -> None:
184187
if not any(
@@ -300,19 +303,15 @@ def get_ns(n: int):
300303

301304
model.validation_summary.add_detail(
302305
ValidationDetail(
303-
name="Run inference for inputs with batch_size:"
304-
+ f" {batch_size} and size parameter n: {n}",
306+
name=f"Run {weight_format} inference for inputs with"
307+
+ f" batch_size: {batch_size} and size parameter n: {n}",
305308
status="passed" if error is None else "failed",
306309
errors=(
307310
[]
308311
if error is None
309312
else [
310313
ErrorEntry(
311-
loc=(
312-
("weights",)
313-
if weight_format is None
314-
else ("weights", weight_format)
315-
),
314+
loc=("weights", weight_format),
316315
msg=error,
317316
type="bioimageio.core",
318317
)
@@ -325,15 +324,11 @@ def get_ns(n: int):
325324
tb = traceback.format_tb(e.__traceback__)
326325
model.validation_summary.add_detail(
327326
ValidationDetail(
328-
name="Run inference for parametrized inputs",
327+
name=f"Run {weight_format} inference for parametrized inputs",
329328
status="failed",
330329
errors=[
331330
ErrorEntry(
332-
loc=(
333-
("weights",)
334-
if weight_format is None
335-
else ("weights", weight_format)
336-
),
331+
loc=("weights", weight_format),
337332
msg=error,
338333
type="bioimageio.core",
339334
traceback=tb,

0 commit comments

Comments
 (0)