Skip to content

Commit 431d758

Browse files
authored
Merge pull request #105 from bioimage-io/fixed_shape_support
Fixed shape support
2 parents 5d02803 + 12b82cf commit 431d758

File tree

12 files changed

+136
-170
lines changed

12 files changed

+136
-170
lines changed

bioimageio/core/prediction.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,7 @@ def predict_with_tiling(prediction_pipeline: PredictionPipeline, inputs, tiling)
295295
):
296296
raise NotImplementedError("Tiling with a different output shape is not yet supported")
297297

298-
ref_input = named_inputs[output_spec.shape.reference_input]
298+
ref_input = named_inputs[output_spec.shape.reference_tensor]
299299
ref_input_shape = dict(zip(ref_input.dims, ref_input.shape))
300300
output_shape = tuple(int(scale[ax] * ref_input_shape[ax] + 2 * offset[ax]) for ax in output_spec.axes)
301301
else:

bioimageio/core/prediction_pipeline/_prediction_pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ def create_prediction_pipeline(
257257
elif isinstance(out.shape, ImplicitOutputShape):
258258
named_output_shape.append(
259259
NamedImplicitOutputShape(
260-
reference_input=out.shape.reference_input,
260+
reference_input=out.shape.reference_tensor,
261261
scale=list(zip(out.axes, out.shape.scale)),
262262
offset=list(zip(out.axes, out.shape.offset)),
263263
)

bioimageio/core/resource_io/nodes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ class URI(Node, raw_nodes.URI):
2727

2828

2929
@dataclass
30-
class ImplicitInputShape(Node, raw_nodes.ImplicitInputShape):
30+
class ParametrizedInputShape(Node, raw_nodes.ParametrizedInputShape):
3131
pass
3232

3333

tests/build_spec/test_build_spec.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
import os
22

3-
import pytest
4-
5-
import bioimageio.spec as spec
63
from marshmallow import missing
74

5+
import bioimageio.spec as spec
86
from bioimageio.core.resource_io.io_ import load_raw_resource_description
97

108

@@ -56,33 +54,27 @@ def _test_build_spec(path, weight_type, tensorflow_version=None):
5654
spec.model.schema.Model().dump(raw_model)
5755

5856

59-
@pytest.mark.skipif(pytest.skip_torch, reason="requires torch")
6057
def test_build_spec_pytorch(any_torch_model):
6158
_test_build_spec(any_torch_model, "pytorch_state_dict")
6259

6360

64-
@pytest.mark.skipif(pytest.skip_torch, reason="requires torch")
6561
def test_build_spec_torchscript(any_torchscript_model):
6662
_test_build_spec(any_torchscript_model, "pytorch_script")
6763

6864

69-
@pytest.mark.skipif(pytest.skip_onnx, reason="requires onnx")
7065
def test_build_spec_onnx(any_onnx_model):
7166
_test_build_spec(any_onnx_model, "onnx")
7267

7368

74-
@pytest.mark.skipif(pytest.skip_tensorflow or pytest.tf_major_version != 1, reason="requires tensorflow 1")
7569
def test_build_spec_keras(any_keras_model):
7670
_test_build_spec(any_keras_model, "keras_hdf5", tensorflow_version="1.12") # todo: keras for tf 2??
7771

7872

79-
@pytest.mark.skipif(pytest.skip_tensorflow, reason="requires tensorflow")
8073
def test_build_spec_tf(any_tensorflow_model):
8174
_test_build_spec(
8275
any_tensorflow_model, "tensorflow_saved_model_bundle", tensorflow_version="1.12"
8376
) # check tf version
8477

8578

86-
@pytest.mark.skipif(pytest.skip_tensorflow, reason="requires tensorflow")
8779
def test_build_spec_tfjs(any_tensorflow_js_model):
8880
_test_build_spec(any_tensorflow_js_model, "tensorflow_js", tensorflow_version="1.12")

tests/conftest.py

Lines changed: 95 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -2,131 +2,152 @@
22
from bioimageio.core import export_resource_package
33

44
# test models for various frameworks
5-
torch_models = ["unet2d_nuclei_broad_model", "unet2d_multi_tensor"]
6-
torchscript_models = ["unet2d_nuclei_broad_model", "unet2d_multi_tensor"]
7-
onnx_models = ["unet2d_nuclei_broad_model", "unet2d_multi_tensor"]
5+
torch_models = ["unet2d_fixed_shape", "unet2d_multi_tensor", "unet2d_nuclei_broad_model"]
6+
torchscript_models = ["unet2d_multi_tensor", "unet2d_nuclei_broad_model"]
7+
onnx_models = ["unet2d_multi_tensor", "unet2d_nuclei_broad_model"]
88
tensorflow1_models = ["FruNet_model"]
99
tensorflow2_models = []
1010
keras_models = ["FruNet_model"]
1111
tensorflow_js_models = ["FruNet_model"]
1212

1313
model_sources = {
14+
"FruNet_model": "https://sandbox.zenodo.org/record/894498/files/rdf.yaml",
15+
# "FruNet_model": "https://raw.githubusercontent.com/deepimagej/models/master/fru-net_sev_segmentation/model.yaml",
1416
"unet2d_nuclei_broad_model": (
1517
"https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/"
1618
"unet2d_nuclei_broad/rdf.yaml"
1719
),
20+
"unet2d_fixed_shape": (
21+
"https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/"
22+
"unet2d_fixed_shape/rdf.yaml"
23+
),
1824
"unet2d_multi_tensor": (
1925
"https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/"
2026
"unet2d_multi_tensor/rdf.yaml"
2127
),
22-
"FruNet_model": "https://sandbox.zenodo.org/record/894498/files/rdf.yaml",
23-
# "FruNet_model": "https://raw.githubusercontent.com/deepimagej/models/master/fru-net_sev_segmentation/model.yaml",
2428
}
2529

26-
# set 'skip_<FRAMEWORK>' flags as global pytest variables,
27-
# to deselect tests that require frameworks not available in current env
28-
def pytest_configure():
29-
try:
30-
import torch
31-
except ImportError:
32-
torch = None
33-
pytest.skip_torch = torch is None
34-
35-
try:
36-
import onnxruntime
37-
except ImportError:
38-
onnxruntime = None
39-
pytest.skip_onnx = onnxruntime is None
30+
try:
31+
import torch
32+
except ImportError:
33+
torch = None
34+
skip_torch = torch is None
35+
36+
try:
37+
import onnxruntime
38+
except ImportError:
39+
onnxruntime = None
40+
skip_onnx = onnxruntime is None
41+
42+
try:
43+
import tensorflow
44+
45+
tf_major_version = int(tensorflow.__version__.split(".")[0])
46+
except ImportError:
47+
tensorflow = None
48+
tf_major_version = None
49+
50+
skip_tensorflow = tensorflow is None
51+
skip_tensorflow = True # todo: update FruNet and remove this
52+
skip_tensorflow_js = True # todo: update FruNet and figure out how to test tensorflow_js weights in python
53+
54+
try:
55+
import keras
56+
except ImportError:
57+
keras = None
58+
skip_keras = keras is None
59+
skip_keras = True # FruNet requires update
60+
61+
# load all model packages we need for testing
62+
load_model_packages = set()
63+
if not skip_torch:
64+
load_model_packages |= set(torch_models + torchscript_models)
65+
66+
if not skip_onnx:
67+
load_model_packages |= set(onnx_models)
68+
69+
if not skip_tensorflow:
70+
load_model_packages |= set(keras_models)
71+
load_model_packages |= set(tensorflow_js_models)
72+
if tf_major_version == 1:
73+
load_model_packages |= set(tensorflow1_models)
74+
elif tf_major_version == 2:
75+
load_model_packages |= set(tensorflow2_models)
4076

41-
try:
42-
import tensorflow
43-
44-
pytest.tf_major_version = int(tensorflow.__version__.split(".")[0])
45-
except ImportError:
46-
tensorflow = None
47-
pytest.skip_tensorflow = tensorflow is None
48-
pytest.skip_tensorflow = True # todo: update FruNet and remove this
49-
50-
try:
51-
import keras
52-
except ImportError:
53-
keras = None
54-
pytest.skip_keras = keras is None
55-
pytest.skip_keras = True # FruNet requires update
56-
57-
# load all model packages we need for testing
58-
load_packages = {"unet2d_nuclei_broad_model"} # always load unet2d_nuclei_broad_model
59-
if not pytest.skip_torch:
60-
load_packages |= set(torch_models + torchscript_models)
61-
62-
if not pytest.skip_onnx:
63-
load_packages |= set(onnx_models)
64-
65-
if not pytest.skip_tensorflow:
66-
load_packages |= set(keras_models)
67-
load_packages |= set(tensorflow_js_models)
68-
if pytest.tf_major_version == 1:
69-
load_packages |= set(tensorflow1_models)
70-
elif pytest.tf_major_version == 2:
71-
load_packages |= set(tensorflow2_models)
72-
73-
pytest.model_packages = {name: export_resource_package(model_sources[name]) for name in load_packages}
7477

78+
def pytest_configure():
7579

76-
@pytest.fixture
77-
def unet2d_nuclei_broad_model():
78-
return pytest.model_packages["unet2d_nuclei_broad_model"]
80+
# explicit skip flags needed for some tests
81+
pytest.skip_torch = skip_torch
82+
pytest.skip_onnx = skip_onnx
7983

84+
# load all model packages used in tests
85+
pytest.model_packages = {name: export_resource_package(model_sources[name]) for name in load_model_packages}
8086

81-
@pytest.fixture
82-
def unet2d_multi_tensor():
83-
return pytest.model_packages["unet2d_multi_tensor"]
8487

88+
#
89+
# model groups of the form any_<weight format>_model that include all models providing a specific weight format
90+
#
8591

86-
@pytest.fixture
87-
def FruNet_model():
88-
return pytest.model_packages["FruNet_model"]
92+
# written as model group to automatically skip on missing torch
93+
@pytest.fixture(params=[] if skip_torch else ["unet2d_nuclei_broad_model"])
94+
def unet2d_nuclei_broad_model(request):
95+
return pytest.model_packages[request.param]
8996

9097

91-
#
92-
# model groups
93-
#
98+
# written as model group to automatically skip on missing tensorflow 1
99+
@pytest.fixture(params=[] if skip_tensorflow or tf_major_version != 1 else ["FruNet_model"])
100+
def FruNet_model(request):
101+
return pytest.model_packages[request.param]
94102

95103

96-
@pytest.fixture(params=torch_models)
104+
@pytest.fixture(params=[] if skip_torch else torch_models)
97105
def any_torch_model(request):
98106
return pytest.model_packages[request.param]
99107

100108

101-
@pytest.fixture(params=torchscript_models)
109+
@pytest.fixture(params=[] if skip_torch else torchscript_models)
102110
def any_torchscript_model(request):
103111
return pytest.model_packages[request.param]
104112

105113

106-
@pytest.fixture(params=onnx_models)
114+
@pytest.fixture(params=[] if skip_onnx else onnx_models)
107115
def any_onnx_model(request):
108116
return pytest.model_packages[request.param]
109117

110118

111-
@pytest.fixture(params=set(tensorflow1_models) | set(tensorflow2_models))
119+
@pytest.fixture(params=[] if skip_tensorflow else (set(tensorflow1_models) | set(tensorflow2_models)))
112120
def any_tensorflow_model(request):
113121
name = request.param
114-
if (pytest.tf_major_version == 1 and name in tensorflow1_models) or (
115-
pytest.tf_major_version == 2 and name in tensorflow2_models
116-
):
122+
if (tf_major_version == 1 and name in tensorflow1_models) or (tf_major_version == 2 and name in tensorflow2_models):
117123
return pytest.model_packages[name]
118124

119125

120-
@pytest.fixture(params=keras_models)
126+
@pytest.fixture(params=[] if skip_keras else keras_models)
121127
def any_keras_model(request):
122128
return pytest.model_packages[request.param]
123129

124130

125-
@pytest.fixture(params=tensorflow_js_models)
131+
@pytest.fixture(params=[] if skip_tensorflow_js else tensorflow_js_models)
126132
def any_tensorflow_js_model(request):
127133
return pytest.model_packages[request.param]
128134

129135

130-
@pytest.fixture(params=model_sources.keys())
136+
# fixture to test with all models that should run in the current environment
137+
@pytest.fixture(params=load_model_packages)
131138
def any_model(request):
132139
return pytest.model_packages[request.param]
140+
141+
142+
#
143+
# temporary fixtures to test not with all, but only a manual selection of models
144+
# (models/functionality should be improved to get rid of this specific model group)
145+
#
146+
@pytest.fixture(params=[] if skip_torch else ["unet2d_nuclei_broad_model", "unet2d_fixed_shape"])
147+
def unet2d_fixed_shape_or_not(request):
148+
return pytest.model_packages[request.param]
149+
150+
151+
@pytest.fixture(params=[] if skip_torch else ["unet2d_nuclei_broad_model", "unet2d_multi_tensor"])
152+
def unet2d_multi_tensor_or_not(request):
153+
return pytest.model_packages[request.param]

tests/prediction_pipeline/test_prediction_pipeline.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import numpy as np
2-
import pytest
32
import xarray as xr
43
from numpy.testing import assert_array_almost_equal
54

@@ -31,31 +30,21 @@ def _test_prediction_pipeline(model_package, weight_format):
3130
assert_array_almost_equal(out, exp, decimal=4)
3231

3332

34-
@pytest.mark.skipif(pytest.skip_torch, reason="requires torch")
3533
def test_prediction_pipeline_torch(any_torch_model):
3634
_test_prediction_pipeline(any_torch_model, "pytorch_state_dict")
3735

3836

39-
@pytest.mark.skipif(pytest.skip_torch, reason="requires torch")
4037
def test_prediction_pipeline_torchscript(any_torchscript_model):
4138
_test_prediction_pipeline(any_torchscript_model, "pytorch_script")
4239

4340

44-
@pytest.mark.skipif(pytest.skip_onnx, reason="requires onnx")
4541
def test_prediction_pipeline_onnx(any_onnx_model):
4642
_test_prediction_pipeline(any_onnx_model, "onnx")
4743

4844

49-
@pytest.mark.skipif(pytest.skip_tensorflow or pytest.tf_major_version != 1, reason="requires tensorflow 1")
50-
def test_prediction_pipeline_tensorflow(any_tensorflow1_model):
51-
_test_prediction_pipeline(any_tensorflow1_model, "tensorflow_saved_model_bundle")
45+
def test_prediction_pipeline_tensorflow(any_tensorflow_model):
46+
_test_prediction_pipeline(any_tensorflow_model, "tensorflow_saved_model_bundle")
5247

5348

54-
@pytest.mark.skipif(pytest.skip_tensorflow or pytest.tf_major_version != 2, reason="requires tensorflow 2")
55-
def test_prediction_pipeline_tensorflow(any_tensorflow2_model):
56-
_test_prediction_pipeline(any_tensorflow2_model, "tensorflow_saved_model_bundle")
57-
58-
59-
@pytest.mark.skipif(pytest.skip_keras, reason="requires keras")
6049
def test_prediction_pipeline_keras(any_keras_model):
6150
_test_prediction_pipeline(any_keras_model, "keras_hdf5")

tests/resource_io/test_load_rdf.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -38,23 +38,20 @@ def test_load_non_valid_rdf_name_invalid_suffix():
3838
load_resource_description(spec_path)
3939

4040

41-
@pytest.mark.skipif(pytest.skip_torch, reason="requires pytorch")
42-
def test_load_raw_model(unet2d_nuclei_broad_model):
41+
def test_load_raw_model(any_model):
4342
from bioimageio.core import load_raw_resource_description
4443

45-
raw_model = load_raw_resource_description(unet2d_nuclei_broad_model)
44+
raw_model = load_raw_resource_description(any_model)
4645
assert raw_model
4746

4847

49-
@pytest.mark.skipif(pytest.skip_torch, reason="requires pytorch")
50-
def test_load_model(unet2d_nuclei_broad_model):
48+
def test_load_model(any_model):
5149
from bioimageio.core import load_resource_description
5250

53-
model = load_resource_description(unet2d_nuclei_broad_model)
51+
model = load_resource_description(any_model)
5452
assert model
5553

5654

57-
@pytest.mark.skipif(pytest.skip_torch, reason="requires pytorch")
5855
def test_load_model_with_abs_path_source(unet2d_nuclei_broad_model):
5956
from bioimageio.core.resource_io import load_raw_resource_description, load_resource_description
6057

@@ -65,7 +62,6 @@ def test_load_model_with_abs_path_source(unet2d_nuclei_broad_model):
6562
assert model
6663

6764

68-
@pytest.mark.skipif(pytest.skip_torch, reason="requires pytorch")
6965
def test_load_model_with_rel_path_source(unet2d_nuclei_broad_model):
7066
from bioimageio.core.resource_io import load_raw_resource_description, load_resource_description
7167

@@ -76,7 +72,6 @@ def test_load_model_with_rel_path_source(unet2d_nuclei_broad_model):
7672
assert model
7773

7874

79-
@pytest.mark.skipif(pytest.skip_torch, reason="requires pytorch")
8075
def test_load_model_with_abs_str_source(unet2d_nuclei_broad_model):
8176
from bioimageio.core.resource_io import load_raw_resource_description, load_resource_description
8277

@@ -87,7 +82,6 @@ def test_load_model_with_abs_str_source(unet2d_nuclei_broad_model):
8782
assert model
8883

8984

90-
@pytest.mark.skipif(pytest.skip_torch, reason="requires pytorch")
9185
def test_load_model_with_rel_str_source(unet2d_nuclei_broad_model):
9286
from bioimageio.core.resource_io import load_raw_resource_description, load_resource_description
9387

0 commit comments

Comments
 (0)