Skip to content

Commit 8b99cee

Browse files
authored
Merge pull request #163 from bioimage-io/fix-shape-val
Fix shape val
2 parents 4e1babf + 82d352e commit 8b99cee

File tree

7 files changed

+106
-73
lines changed

7 files changed

+106
-73
lines changed

bioimageio/core/__main__.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import enum
22
import json
33
import os
4+
import sys
45
from glob import glob
56

67
from pathlib import Path
@@ -53,9 +54,10 @@ def package(
5354
# typer bug: typer returns empty tuple instead of None if weights_order_priority is not given
5455
weights_priority_order = weights_priority_order or None
5556

56-
return commands.package(
57+
ret_code = commands.package(
5758
rdf_source=rdf_source, path=path, weights_priority_order=weights_priority_order, verbose=verbose
5859
)
60+
sys.exit(ret_code)
5961

6062

6163
package.__doc__ = commands.package.__doc__
@@ -89,11 +91,12 @@ def test_model(
8991
)
9092
if summary["error"] is None:
9193
print(f"Model test for {model_rdf} has passed.")
92-
return 0
94+
ret_code = 0
9395
else:
9496
print(f"Model test for {model_rdf} has FAILED!")
9597
print(summary)
96-
return 1
98+
ret_code = 1
99+
sys.exit(ret_code)
97100

98101

99102
test_model.__doc__ = resource_tests.test_model.__doc__
@@ -116,11 +119,12 @@ def test_resource(
116119
)
117120
if summary["error"] is None:
118121
print(f"Resource test for {rdf} has passed.")
119-
return 0
122+
ret_code = 0
120123
else:
121124
print(f"Resource test for {rdf} has FAILED!")
122125
print(summary)
123-
return 1
126+
ret_code = 1
127+
sys.exit(ret_code)
124128

125129

126130
test_resource.__doc__ = resource_tests.test_resource.__doc__
@@ -159,7 +163,6 @@ def predict_image(
159163
prediction.predict_image(
160164
model_rdf, inputs, outputs, padding, tiling, None if weight_format is None else weight_format.value, devices
161165
)
162-
return 0
163166

164167

165168
predict_image.__doc__ = prediction.predict_image.__doc__
@@ -211,7 +214,6 @@ def predict_images(
211214
devices=devices,
212215
verbose=True,
213216
)
214-
return 0
215217

216218

217219
predict_images.__doc__ = prediction.predict_images.__doc__
@@ -241,7 +243,8 @@ def convert_torch_weights_to_torchscript(
241243
output_path: Path = typer.Argument(..., help="Where to save the torchscript weights."),
242244
use_tracing: bool = typer.Option(True, help="Whether to use torch.jit tracing or scripting."),
243245
) -> int:
244-
return torch_converter.convert_weights_to_pytorch_script(model_rdf, output_path, use_tracing)
246+
ret_code = torch_converter.convert_weights_to_pytorch_script(model_rdf, output_path, use_tracing)
247+
sys.exit(ret_code)
245248

246249
convert_torch_weights_to_torchscript.__doc__ = torch_converter.convert_weights_to_pytorch_script.__doc__
247250

bioimageio/core/commands.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,13 @@ def package(
3232

3333
try:
3434
rdf_local_source = resolve_uri(rdf_source)
35+
except Exception as e:
36+
print(f"Failed to resolve RDF source {rdf_source}: {e}")
37+
if verbose:
38+
traceback.print_exc()
39+
return 1
40+
41+
try:
3542
path = path.with_name(path.name.format(src_name=rdf_local_source.stem))
3643
shutil.move(tmp_package_path, path)
3744
except Exception as e:

bioimageio/core/resource_tests.py

Lines changed: 41 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,41 @@ def test_model(
3838
return {"error": f"Expected RDF type Model, got {type(model)} instead.", "traceback": None}
3939

4040

41+
def _validate_input_shape(shape: Tuple[int, ...], shape_spec) -> bool:
42+
if isinstance(shape_spec, list):
43+
if shape != tuple(shape_spec):
44+
return False
45+
elif isinstance(shape_spec, ParametrizedInputShape):
46+
assert len(shape_spec.min) == len(shape_spec.step)
47+
if len(shape) != len(shape_spec.min):
48+
return False
49+
min_shape = shape_spec.min
50+
step = shape_spec.step
51+
# check if the shape is valid for all dimension by seeing if it can be reached with an integer number of steps
52+
# NOTE we allow that the valid shape is reached using a different number of steps for each axis here
53+
# this is usually valid because dimensions are independent in neural networks
54+
is_valid = [(sh - minsh) % st == 0 if st > 0 else sh == minsh for sh, st, minsh in zip(shape, step, min_shape)]
55+
return all(is_valid)
56+
else:
57+
raise TypeError(f"Encountered unexpected shape description of type {type(shape_spec)}")
58+
59+
return True
60+
61+
62+
def _validate_output_shape(shape: Tuple[int, ...], shape_spec, input_shapes) -> bool:
63+
if isinstance(shape_spec, list):
64+
return shape == tuple(shape_spec)
65+
elif isinstance(shape_spec, ImplicitOutputShape):
66+
ipt_shape = numpy.array(input_shapes[shape_spec.reference_tensor])
67+
scale = numpy.array(shape_spec.scale)
68+
offset = numpy.array(shape_spec.offset)
69+
exp_shape = numpy.round_(ipt_shape * scale) + 2 * offset
70+
71+
return shape == tuple(exp_shape)
72+
else:
73+
raise TypeError(f"Encountered unexpected shape description of type {type(shape_spec)}")
74+
75+
4176
def test_resource(
4277
model_rdf: Union[RawResourceDescription, ResourceDescription, URI, Path, str],
4378
*,
@@ -64,65 +99,22 @@ def test_resource(
6499
inputs = [np.load(str(in_path)) for in_path in model.test_inputs]
65100
expected = [np.load(str(out_path)) for out_path in model.test_outputs]
66101

67-
# check if test data shapes match their description
68-
input_shapes = [ipt.shape for ipt in inputs]
69-
output_shapes = [out.shape for out in expected]
70-
71-
def input_shape_is_valid(shape: Tuple[int, ...], shape_spec) -> bool:
72-
if isinstance(shape_spec, list):
73-
if shape != tuple(shape_spec):
74-
return False
75-
elif isinstance(shape_spec, ParametrizedInputShape):
76-
assert len(shape_spec.min) == len(shape_spec.step)
77-
if len(shape) != len(shape_spec.min):
78-
return False
79-
80-
valid_shape = numpy.array(shape_spec.min)
81-
step = numpy.array(shape_spec.step)
82-
if (step == 0).all():
83-
return shape == tuple(valid_shape)
84-
85-
shape = numpy.array(shape)
86-
while (shape <= valid_shape).all():
87-
if (shape == valid_shape).all():
88-
break
89-
90-
shape += step
91-
else:
92-
return False
93-
94-
else:
95-
raise TypeError(f"Encountered unexpected shape description of type {type(shape_spec)}")
96-
97-
return True
98-
99102
assert len(inputs) == len(model.inputs) # should be checked by validation
100103
input_shapes = {}
101104
for idx, (ipt, ipt_spec) in enumerate(zip(inputs, model.inputs)):
102-
if not input_shape_is_valid(ipt, ipt_spec.shape):
105+
if not _validate_input_shape(tuple(ipt.shape), ipt_spec.shape):
103106
raise ValidationError(
104-
f"Shape of test input {idx} '{ipt_spec.name}' does not match "
105-
f"input shape description: {ipt_spec.shape}"
107+
f"Shape {tuple(ipt.shape)} of test input {idx} '{ipt_spec.name}' does not match "
108+
f"input shape description: {ipt_spec.shape}."
106109
)
107110
input_shapes[ipt_spec.name] = ipt.shape
108111

109-
def output_shape_is_valid(shape: Tuple[int, ...], shape_spec) -> bool:
110-
if isinstance(shape_spec, list):
111-
return shape == tuple(shape_spec)
112-
elif isinstance(shape_spec, ImplicitOutputShape):
113-
ipt_shape = numpy.array(input_shapes[shape_spec.reference_tensor])
114-
scale = numpy.array(shape_spec.scale)
115-
offset = numpy.array(shape_spec.offset)
116-
exp_shape = numpy.round_(ipt_shape * scale) + 2 * offset
117-
118-
return shape == tuple(exp_shape)
119-
120112
assert len(expected) == len(model.outputs) # should be checked by validation
121113
for idx, (out, out_spec) in enumerate(zip(expected, model.outputs)):
122-
if not output_shape_is_valid(out, out_spec.shape):
114+
if not _validate_output_shape(tuple(out.shape), out_spec.shape, input_shapes):
123115
error = (error or "") + (
124-
f"Shape of test output {idx} '{out_spec.name}' does not match "
125-
f"output shape description: {out_spec.shape}.\n"
116+
f"Shape {tuple(out.shape)} of test output {idx} '{out_spec.name}' does not match "
117+
f"output shape description: {out_spec.shape}."
126118
)
127119

128120
with create_prediction_pipeline(

tests/conftest.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,10 @@
4242
"https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/"
4343
"stardist_example_model/rdf_wrong_shape.yaml"
4444
),
45+
"stardist_wrong_shape2": (
46+
"https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/"
47+
"stardist_example_model/rdf_wrong_shape2.yaml"
48+
),
4549
}
4650

4751
try:
@@ -94,6 +98,7 @@
9498
if tf_major_version == 1:
9599
load_model_packages |= set(tensorflow1_models)
96100
load_model_packages.add("stardist_wrong_shape")
101+
load_model_packages.add("stardist_wrong_shape2")
97102
elif tf_major_version == 2:
98103
load_model_packages |= set(tensorflow2_models)
99104

@@ -124,6 +129,12 @@ def stardist_wrong_shape(request):
124129
return pytest.model_packages[request.param]
125130

126131

132+
# written as model group to automatically skip on missing tensorflow 1
133+
@pytest.fixture(params=[] if skip_tensorflow or tf_major_version != 1 else ["stardist_wrong_shape2"])
134+
def stardist_wrong_shape2(request):
135+
return pytest.model_packages[request.param]
136+
137+
127138
# written as model group to automatically skip on missing tensorflow 1
128139
@pytest.fixture(params=[] if skip_tensorflow or tf_major_version != 1 else ["stardist"])
129140
def stardist(request):
@@ -165,7 +176,7 @@ def any_tensorflow_js_model(request):
165176
# fixture to test with all models that should run in the current environment
166177
# we exclude stardist_wrong_shape here because it is not a valid model
167178
# and included only to test that validation for this model fails
168-
@pytest.fixture(params=load_model_packages - {"stardist_wrong_shape"})
179+
@pytest.fixture(params=load_model_packages - {"stardist_wrong_shape", "stardist_wrong_shape2"})
169180
def any_model(request):
170181
return pytest.model_packages[request.param]
171182

tests/test_cli.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,11 @@ def test_cli_test_model(unet2d_nuclei_broad_model):
2121
assert ret.returncode == 0
2222

2323

24+
def test_cli_test_model_fail(stardist_wrong_shape):
25+
ret = subprocess.run(["bioimageio", "test-model", stardist_wrong_shape])
26+
assert ret.returncode == 1
27+
28+
2429
def test_cli_test_model_with_weight_format(unet2d_nuclei_broad_model):
2530
ret = subprocess.run(
2631
["bioimageio", "test-model", unet2d_nuclei_broad_model, "--weight-format", "pytorch_state_dict"]

tests/test_prediction.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,6 @@
88
from bioimageio.core.resource_io.nodes import Model
99

1010

11-
def test_test_model(any_model):
12-
from bioimageio.core.resource_tests import test_model
13-
14-
assert test_model(any_model)
15-
16-
17-
def test_test_resource(any_model):
18-
from bioimageio.core.resource_tests import test_resource
19-
20-
assert test_resource(any_model)
21-
22-
2311
def test_predict_image(any_model, tmpdir):
2412
from bioimageio.core.prediction import predict_image
2513

tests/test_resource_tests/test_test_model.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,34 @@ def test_error_for_wrong_shape(stardist_wrong_shape):
22
from bioimageio.core.resource_tests import test_model
33

44
summary = test_model(stardist_wrong_shape)
5-
assert (
6-
summary["error"]
7-
== "Shape of test input 0 'input' does not match input shape description: ParametrizedInputShape(min=[1, 16, 16, 1], step=[0, 16, 16, 0])"
5+
expected_error_message = (
6+
"Shape (1, 512, 512, 33) of test output 0 'output' does not match output shape description: "
7+
"ImplicitOutputShape(reference_tensor='input', "
8+
"scale=[1.0, 1.0, 1.0, 0.0], offset=[1.0, 1.0, 1.0, 33.0])."
89
)
10+
assert summary["error"] == expected_error_message
11+
12+
13+
def test_error_for_wrong_shape2(stardist_wrong_shape2):
14+
from bioimageio.core.resource_tests import test_model
15+
16+
summary = test_model(stardist_wrong_shape2)
17+
expected_error_message = (
18+
"Shape (1, 512, 512, 1) of test input 0 'input' does not match input shape description: "
19+
"ParametrizedInputShape(min=[1, 16, 16, 1], step=[0, 17, 17, 0])."
20+
)
21+
assert summary["error"] == expected_error_message
22+
23+
24+
def test_test_model(any_model):
25+
from bioimageio.core.resource_tests import test_model
26+
27+
summary = test_model(any_model)
28+
assert summary["error"] is None
29+
30+
31+
def test_test_resource(any_model):
32+
from bioimageio.core.resource_tests import test_resource
33+
34+
summary = test_resource(any_model)
35+
assert summary["error"] is None

0 commit comments

Comments
 (0)