Skip to content

Commit 2edac77

Browse files
Update model test to handle expanded output shape and add test model
1 parent 7af0e0f commit 2edac77

File tree

2 files changed

+11
-1
lines changed

2 files changed

+11
-1
lines changed

bioimageio/core/resource_tests.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def _validate_output_shape(shape: Tuple[int, ...], shape_spec, input_shapes) ->
8989
if ref_tensor not in input_shapes:
9090
raise ValidationError(f"The reference tensor name {ref_tensor} is not in {input_shapes}")
9191
ipt_shape = numpy.array(input_shapes[ref_tensor])
92-
scale = numpy.array(shape_spec.scale)
92+
scale = numpy.array([0.0 if sc is None else sc for sc in shape_spec.scale])
9393
offset = numpy.array(shape_spec.offset)
9494
exp_shape = numpy.round_(ipt_shape * scale) + 2 * offset
9595

tests/conftest.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@
3232
"https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/"
3333
"unet2d_nuclei_broad/rdf.yaml"
3434
),
35+
"unet2d_expand_output_shape": (
36+
"https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/"
37+
"unet2d_nuclei_broad/rdf_expand_output_shape.yaml"
38+
),
3539
"unet2d_fixed_shape": (
3640
"https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/"
3741
"unet2d_fixed_shape/rdf.yaml"
@@ -205,6 +209,12 @@ def unet2d_diff_output_shape(request):
205209
return pytest.model_packages[request.param]
206210

207211

212+
# written as model group to automatically skip on missing torch
213+
@pytest.fixture(params=[] if skip_torch else ["unet2d_expand_output_shape"])
214+
def unet2d_expand_output_shape(request):
215+
return pytest.model_packages[request.param]
216+
217+
208218
# written as model group to automatically skip on missing torch
209219
@pytest.fixture(params=[] if skip_torch else ["unet2d_fixed_shape"])
210220
def unet2d_fixed_shape(request):

0 commit comments

Comments
 (0)