Skip to content

Commit 7f51b19

Browse files
Update conftest, test prediction functions with any_model, fix issues in prediction impl
1 parent 2a228e6 commit 7f51b19

File tree

3 files changed

+36
-36
lines changed

3 files changed

+36
-36
lines changed

bioimageio/core/prediction.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,8 @@ def predict_with_tiling(prediction_pipeline: PredictionPipeline, inputs, tiling)
318318

319319

320320
def parse_padding(padding, model):
321+
if padding is None: # no padding
322+
return padding
321323
if len(model.inputs) > 1:
322324
raise NotImplementedError("Padding for multiple inputs not yet implemented")
323325

@@ -327,9 +329,7 @@ def parse_padding(padding, model):
327329
def check_padding(padding):
328330
assert all(k in pad_keys for k in padding.keys())
329331

330-
if padding is None: # no padding
331-
return padding
332-
elif isinstance(padding, dict): # pre-defined padding
332+
if isinstance(padding, dict): # pre-defined padding
333333
check_padding(padding)
334334
elif isinstance(padding, bool): # determine padding from spec
335335
if padding:
@@ -351,6 +351,8 @@ def check_padding(padding):
351351

352352

353353
def parse_tiling(tiling, model):
354+
if tiling is None: # no tiling
355+
return tiling
354356
if len(model.inputs) > 1:
355357
raise NotImplementedError("Tiling for multiple inputs not yet implemented")
356358

@@ -363,9 +365,7 @@ def parse_tiling(tiling, model):
363365
def check_tiling(tiling):
364366
assert "halo" in tiling and "tile" in tiling
365367

366-
if tiling is None: # no tiling
367-
return tiling
368-
elif isinstance(tiling, dict):
368+
if isinstance(tiling, dict):
369369
check_tiling(tiling)
370370
elif isinstance(tiling, bool):
371371
if tiling:

tests/conftest.py

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,14 @@
1111
torch_models_pre_3_10 = ["unet2d_fixed_shape", "unet2d_multi_tensor", "unet2d_nuclei_broad_model"]
1212
torchscript_models = ["unet2d_multi_tensor", "unet2d_nuclei_broad_model"]
1313
onnx_models = ["unet2d_multi_tensor", "unet2d_nuclei_broad_model", "hpa_densenet"]
14-
tensorflow1_models = ["FruNet_model", "stardist"]
14+
tensorflow1_models = ["stardist"]
1515
tensorflow2_models = []
16-
keras_models = ["FruNet_model"]
17-
tensorflow_js_models = ["FruNet_model"]
16+
keras_models = []
17+
tensorflow_js_models = []
1818

1919
model_sources = {
20-
"FruNet_model": "https://sandbox.zenodo.org/record/894498/files/rdf.yaml",
21-
# "FruNet_model": "https://raw.githubusercontent.com/deepimagej/models/master/fru-net_sev_segmentation/model.yaml",
20+
# TODO add unet2d_keras_tf from https://github.com/bioimage-io/spec-bioimage-io/pull/267
21+
# "unet2d_keras_tf": (""),
2222
"unet2d_nuclei_broad_model": (
2323
"https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/"
2424
"unet2d_nuclei_broad/rdf.yaml"
@@ -35,10 +35,12 @@
3535
"https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/hpa-densenet/rdf.yaml"
3636
),
3737
"stardist": (
38-
"https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/stardist_example_model/rdf.yaml"
38+
"https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models"
39+
"/stardist_example_model/rdf.yaml"
3940
),
4041
"stardist_wrong_shape": (
41-
"https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/stardist_example_model/rdf_wrong_shape.yaml"
42+
"https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/"
43+
"stardist_example_model/rdf_wrong_shape.yaml"
4244
),
4345
}
4446

@@ -50,7 +52,6 @@
5052
except ImportError:
5153
torch = None
5254
torch_version = None
53-
5455
skip_torch = torch is None
5556

5657
try:
@@ -66,17 +67,15 @@
6667
except ImportError:
6768
tensorflow = None
6869
tf_major_version = None
69-
7070
skip_tensorflow = tensorflow is None
71-
skip_tensorflow = True # todo: update FruNet and remove this
72-
skip_tensorflow_js = True # todo: update FruNet and figure out how to test tensorflow_js weights in python
71+
skip_tensorflow_js = True # TODO: add a tensorflow_js example model
7372

7473
try:
7574
import keras
7675
except ImportError:
7776
keras = None
7877
skip_keras = keras is None
79-
skip_keras = True # FruNet requires update
78+
skip_keras = True # TODO add unet2d_keras_tf to have a model for keras tests
8079

8180
# load all model packages we need for testing
8281
load_model_packages = set()
@@ -120,14 +119,14 @@ def unet2d_nuclei_broad_model(request):
120119

121120

122121
# written as model group to automatically skip on missing tensorflow 1
123-
@pytest.fixture(params=[] if skip_tensorflow or tf_major_version != 1 else ["FruNet_model"])
124-
def FruNet_model(request):
122+
@pytest.fixture(params=[] if skip_tensorflow or tf_major_version != 1 else ["stardist_wrong_shape"])
123+
def stardist_wrong_shape(request):
125124
return pytest.model_packages[request.param]
126125

127126

128127
# written as model group to automatically skip on missing tensorflow 1
129-
@pytest.fixture(params=[] if skip_tensorflow or tf_major_version != 1 else ["stardist_wrong_shape"])
130-
def stardist_wrong_shape(request):
128+
@pytest.fixture(params=[] if skip_tensorflow or tf_major_version != 1 else ["stardist"])
129+
def stardist(request):
131130
return pytest.model_packages[request.param]
132131

133132

@@ -164,7 +163,9 @@ def any_tensorflow_js_model(request):
164163

165164

166165
# fixture to test with all models that should run in the current environment
167-
@pytest.fixture(params=load_model_packages)
166+
# we exclude stardist_wrong_shape here because it is not a valid model
167+
# and included only to test that validation for this model fails
168+
@pytest.fixture(params=load_model_packages - {"stardist_wrong_shape"})
168169
def any_model(request):
169170
return pytest.model_packages[request.param]
170171

tests/test_prediction.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,20 +8,19 @@
88
from bioimageio.core.resource_io.nodes import Model
99

1010

11-
def test_test_model(unet2d_nuclei_broad_model):
11+
def test_test_model(any_model):
1212
from bioimageio.core.resource_tests import test_model
1313

14-
assert test_model(unet2d_nuclei_broad_model)
14+
assert test_model(any_model)
1515

1616

17-
def test_test_resource(unet2d_nuclei_broad_model):
17+
def test_test_resource(any_model):
1818
from bioimageio.core.resource_tests import test_resource
1919

20-
assert test_resource(unet2d_nuclei_broad_model)
20+
assert test_resource(any_model)
2121

2222

23-
def test_predict_image(unet2d_fixed_shape_or_not, tmpdir):
24-
any_model = unet2d_fixed_shape_or_not # todo: replace 'unet2d_fixed_shape_or_not' with 'any_model'
23+
def test_predict_image(any_model, tmpdir):
2524
from bioimageio.core.prediction import predict_image
2625

2726
spec = load_resource_description(any_model)
@@ -92,15 +91,14 @@ def check_result():
9291
check_result()
9392

9493

95-
# prediction with padding with the parameters above may not be suted for any model
94+
# prediction with padding with the parameters above may not be suited for any model
9695
# so we only run it for the pytorch unet2d here
9796
def test_predict_image_with_padding(unet2d_fixed_shape_or_not, tmp_path):
9897
_test_predict_with_padding(unet2d_fixed_shape_or_not, tmp_path)
9998

10099

101-
# TODO need stardist model
102-
# def test_predict_image_with_padding_channel_last(stardist_model, tmp_path):
103-
# _test_predict_with_padding(stardist_model, tmp_path)
100+
def test_predict_image_with_padding_channel_last(stardist, tmp_path):
101+
_test_predict_with_padding(stardist, tmp_path)
104102

105103

106104
def _test_predict_image_with_tiling(model, tmp_path):
@@ -132,13 +130,14 @@ def check_result():
132130
check_result()
133131

134132

133+
# prediction with tiling with the parameters above may not be suited for any model
134+
# so we only run it for the pytorch unet2d here
135135
def test_predict_image_with_tiling(unet2d_nuclei_broad_model, tmp_path):
136136
_test_predict_image_with_tiling(unet2d_nuclei_broad_model, tmp_path)
137137

138138

139-
# TODO need stardist model
140-
# def test_predict_image_with_tiling_channel_last(stardist_model, tmp_path):
141-
# _test_predict_image_with_tiling(stardist_model, tmp_path)
139+
def test_predict_image_with_tiling_channel_last(stardist, tmp_path):
140+
_test_predict_image_with_tiling(stardist, tmp_path)
142141

143142

144143
def test_predict_images(unet2d_nuclei_broad_model, tmp_path):

0 commit comments

Comments
 (0)