Skip to content

Commit 4e1babf

Browse files
authored
Merge pull request #161 from bioimage-io/fix-tiling
Fix tiling
2 parents aaa4c22 + 1816b62 commit 4e1babf

File tree

4 files changed

+120
-50
lines changed

4 files changed

+120
-50
lines changed

bioimageio/core/prediction.py

Lines changed: 37 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ def load_tile(tile):
219219
inp = input_[tile]
220220
# whether to pad on the right or left of the dim for the spatial dims
221221
# + placeholders for batch and axis dimension, where we don't pad
222-
pad_right = [None, None] + [tile[ax].start == 0 for ax in input_axes if ax in "xyz"]
222+
pad_right = [tile[ax].start == 0 if ax in "xyz" else None for ax in input_axes]
223223
return inp, pad_right
224224

225225
# we need to use padded prediction for the individual tiles in case the
@@ -318,6 +318,8 @@ def predict_with_tiling(prediction_pipeline: PredictionPipeline, inputs, tiling)
318318

319319

320320
def parse_padding(padding, model):
321+
if padding is None: # no padding
322+
return padding
321323
if len(model.inputs) > 1:
322324
raise NotImplementedError("Padding for multiple inputs not yet implemented")
323325

@@ -327,9 +329,7 @@ def parse_padding(padding, model):
327329
def check_padding(padding):
328330
assert all(k in pad_keys for k in padding.keys())
329331

330-
if padding is None: # no padding
331-
return padding
332-
elif isinstance(padding, dict): # pre-defined padding
332+
if isinstance(padding, dict): # pre-defined padding
333333
check_padding(padding)
334334
elif isinstance(padding, bool): # determine padding from spec
335335
if padding:
@@ -350,7 +350,25 @@ def check_padding(padding):
350350
return padding
351351

352352

353+
# simple heuristic to determine suitable shape from min and step
354+
def _determine_shape(min_shape, step, axes):
355+
is3d = "z" in axes
356+
min_len = 64 if is3d else 256
357+
shape = []
358+
for ax, min_ax, step_ax in zip(axes, min_shape, step):
359+
if ax in "zyx" and step_ax > 0:
360+
len_ax = min_ax
361+
while len_ax < min_len:
362+
len_ax += step_ax
363+
shape.append(len_ax)
364+
else:
365+
shape.append(min_ax)
366+
return shape
367+
368+
353369
def parse_tiling(tiling, model):
370+
if tiling is None: # no tiling
371+
return tiling
354372
if len(model.inputs) > 1:
355373
raise NotImplementedError("Tiling for multiple inputs not yet implemented")
356374

@@ -359,13 +377,17 @@ def parse_tiling(tiling, model):
359377

360378
input_spec = model.inputs[0]
361379
output_spec = model.outputs[0]
380+
axes = input_spec.axes
362381

363382
def check_tiling(tiling):
364383
assert "halo" in tiling and "tile" in tiling
384+
spatial_axes = [ax for ax in axes if ax in "xyz"]
385+
halo = tiling["halo"]
386+
tile = tiling["tile"]
387+
assert all(halo.get(ax, 0) > 0 for ax in spatial_axes)
388+
assert all(tile.get(ax, 0) > 0 for ax in spatial_axes)
365389

366-
if tiling is None: # no tiling
367-
return tiling
368-
elif isinstance(tiling, dict):
390+
if isinstance(tiling, dict):
369391
check_tiling(tiling)
370392
elif isinstance(tiling, bool):
371393
if tiling:
@@ -374,18 +396,21 @@ def check_tiling(tiling):
374396
# output space and then request the corresponding input tiles
375397
# so we would need to apply the output scale and offset to the
376398
# input shape to compute the tile size and halo here
377-
axes = input_spec.axes
378399
shape = input_spec.shape
379400
if not isinstance(shape, list):
380-
# NOTE this might result in very small tiles.
381-
# it would be good to have some heuristic to determine a suitable tilesize
382-
# from shape.min and shape.step
383-
shape = shape.min
401+
shape = _determine_shape(shape.min, shape.step, axes)
402+
assert isinstance(shape, list)
403+
assert len(shape) == len(axes)
404+
384405
halo = output_spec.halo
406+
if halo is None:
407+
raise ValueError("Model does not provide a valid halo to use for tiling with default parameters")
408+
385409
tiling = {
386410
"halo": {ax: ha for ax, ha in zip(axes, halo) if ax in "xyz"},
387411
"tile": {ax: sh for ax, sh in zip(axes, shape) if ax in "xyz"},
388412
}
413+
check_tiling(tiling)
389414
else:
390415
tiling = None
391416
else:

dev/environment-torch.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,5 @@ dependencies:
1212
- pytest
1313
- python >=3.7
1414
- xarray
15-
- pytorch
15+
- pytorch <1.10
1616
- onnxruntime

tests/conftest.py

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,14 @@
1111
torch_models_pre_3_10 = ["unet2d_fixed_shape", "unet2d_multi_tensor", "unet2d_nuclei_broad_model"]
1212
torchscript_models = ["unet2d_multi_tensor", "unet2d_nuclei_broad_model"]
1313
onnx_models = ["unet2d_multi_tensor", "unet2d_nuclei_broad_model", "hpa_densenet"]
14-
tensorflow1_models = ["FruNet_model", "stardist"]
14+
tensorflow1_models = ["stardist"]
1515
tensorflow2_models = []
16-
keras_models = ["FruNet_model"]
17-
tensorflow_js_models = ["FruNet_model"]
16+
keras_models = []
17+
tensorflow_js_models = []
1818

1919
model_sources = {
20-
"FruNet_model": "https://sandbox.zenodo.org/record/894498/files/rdf.yaml",
21-
# "FruNet_model": "https://raw.githubusercontent.com/deepimagej/models/master/fru-net_sev_segmentation/model.yaml",
20+
# TODO add unet2d_keras_tf from https://github.com/bioimage-io/spec-bioimage-io/pull/267
21+
# "unet2d_keras_tf": (""),
2222
"unet2d_nuclei_broad_model": (
2323
"https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/"
2424
"unet2d_nuclei_broad/rdf.yaml"
@@ -35,10 +35,12 @@
3535
"https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/hpa-densenet/rdf.yaml"
3636
),
3737
"stardist": (
38-
"https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/stardist_example_model/rdf.yaml"
38+
"https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models"
39+
"/stardist_example_model/rdf.yaml"
3940
),
4041
"stardist_wrong_shape": (
41-
"https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/stardist_example_model/rdf_wrong_shape.yaml"
42+
"https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/"
43+
"stardist_example_model/rdf_wrong_shape.yaml"
4244
),
4345
}
4446

@@ -50,7 +52,6 @@
5052
except ImportError:
5153
torch = None
5254
torch_version = None
53-
5455
skip_torch = torch is None
5556

5657
try:
@@ -66,17 +67,15 @@
6667
except ImportError:
6768
tensorflow = None
6869
tf_major_version = None
69-
7070
skip_tensorflow = tensorflow is None
71-
skip_tensorflow = True # todo: update FruNet and remove this
72-
skip_tensorflow_js = True # todo: update FruNet and figure out how to test tensorflow_js weights in python
71+
skip_tensorflow_js = True # TODO: add a tensorflow_js example model
7372

7473
try:
7574
import keras
7675
except ImportError:
7776
keras = None
7877
skip_keras = keras is None
79-
skip_keras = True # FruNet requires update
78+
skip_keras = True # TODO add unet2d_keras_tf to have a model for keras tests
8079

8180
# load all model packages we need for testing
8281
load_model_packages = set()
@@ -120,14 +119,14 @@ def unet2d_nuclei_broad_model(request):
120119

121120

122121
# written as model group to automatically skip on missing tensorflow 1
123-
@pytest.fixture(params=[] if skip_tensorflow or tf_major_version != 1 else ["FruNet_model"])
124-
def FruNet_model(request):
122+
@pytest.fixture(params=[] if skip_tensorflow or tf_major_version != 1 else ["stardist_wrong_shape"])
123+
def stardist_wrong_shape(request):
125124
return pytest.model_packages[request.param]
126125

127126

128127
# written as model group to automatically skip on missing tensorflow 1
129-
@pytest.fixture(params=[] if skip_tensorflow or tf_major_version != 1 else ["stardist_wrong_shape"])
130-
def stardist_wrong_shape(request):
128+
@pytest.fixture(params=[] if skip_tensorflow or tf_major_version != 1 else ["stardist"])
129+
def stardist(request):
131130
return pytest.model_packages[request.param]
132131

133132

@@ -164,7 +163,9 @@ def any_tensorflow_js_model(request):
164163

165164

166165
# fixture to test with all models that should run in the current environment
167-
@pytest.fixture(params=load_model_packages)
166+
# we exclude stardist_wrong_shape here because it is not a valid model
167+
# and included only to test that validation for this model fails
168+
@pytest.fixture(params=load_model_packages - {"stardist_wrong_shape"})
168169
def any_model(request):
169170
return pytest.model_packages[request.param]
170171

tests/test_prediction.py

Lines changed: 64 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -8,20 +8,19 @@
88
from bioimageio.core.resource_io.nodes import Model
99

1010

11-
def test_test_model(unet2d_nuclei_broad_model):
11+
def test_test_model(any_model):
1212
from bioimageio.core.resource_tests import test_model
1313

14-
assert test_model(unet2d_nuclei_broad_model)
14+
assert test_model(any_model)
1515

1616

17-
def test_test_resource(unet2d_nuclei_broad_model):
17+
def test_test_resource(any_model):
1818
from bioimageio.core.resource_tests import test_resource
1919

20-
assert test_resource(unet2d_nuclei_broad_model)
20+
assert test_resource(any_model)
2121

2222

23-
def test_predict_image(unet2d_fixed_shape_or_not, tmpdir):
24-
any_model = unet2d_fixed_shape_or_not # todo: replace 'unet2d_fixed_shape_or_not' with 'any_model'
23+
def test_predict_image(any_model, tmpdir):
2524
from bioimageio.core.prediction import predict_image
2625

2726
spec = load_resource_description(any_model)
@@ -57,46 +56,81 @@ def test_predict_image_with_weight_format(unet2d_fixed_shape_or_not, tmpdir):
5756
assert_array_almost_equal(res, exp, decimal=4)
5857

5958

60-
def test_predict_image_with_padding(unet2d_fixed_shape_or_not, tmp_path):
61-
any_model = unet2d_fixed_shape_or_not # todo: replace 'unet2d_fixed_shape_or_not' with 'any_model'
59+
def _test_predict_with_padding(model, tmp_path):
6260
from bioimageio.core.prediction import predict_image
6361

64-
spec = load_resource_description(any_model)
62+
spec = load_resource_description(model)
6563
assert isinstance(spec, Model)
66-
image = np.load(str(spec.test_inputs[0]))[0, 0]
64+
65+
input_spec, output_spec = spec.inputs[0], spec.outputs[0]
66+
channel_axis = input_spec.axes.index("c")
67+
channel_first = channel_axis == 1
68+
69+
image = np.load(str(spec.test_inputs[0]))
70+
assert image.shape[channel_axis] == 1
71+
if channel_first:
72+
image = image[0, 0]
73+
else:
74+
image = image[0, ..., 0]
6775
original_shape = image.shape
6876
assert image.ndim == 2
6977

78+
if isinstance(output_spec.shape, list):
79+
n_channels = output_spec.shape[channel_axis]
80+
else:
81+
scale = output_spec.shape.scale[channel_axis]
82+
offset = output_spec.shape.offset[channel_axis]
83+
in_channels = 1
84+
n_channels = int(2 * offset + scale * in_channels)
85+
7086
# write the padded image
7187
image = image[3:-2, 1:-12]
7288
in_path = tmp_path / "in.tif"
7389
out_path = tmp_path / "out.tif"
7490
imageio.imwrite(in_path, image)
7591

7692
def check_result():
77-
assert out_path.exists()
78-
res = imageio.imread(out_path)
79-
assert res.shape == image.shape
93+
if n_channels == 1:
94+
assert out_path.exists()
95+
res = imageio.imread(out_path)
96+
assert res.shape == image.shape
97+
else:
98+
path = str(out_path)
99+
for c in range(n_channels):
100+
channel_out_path = Path(path.replace(".tif", f"-c{c}.tif"))
101+
assert channel_out_path.exists()
102+
res = imageio.imread(channel_out_path)
103+
assert res.shape == image.shape
80104

81105
# test with dynamic padding
82-
predict_image(any_model, in_path, out_path, padding={"x": 8, "y": 8, "mode": "dynamic"})
106+
predict_image(model, in_path, out_path, padding={"x": 16, "y": 16, "mode": "dynamic"})
83107
check_result()
84108

85109
# test with fixed padding
86110
predict_image(
87-
any_model, in_path, out_path, padding={"x": original_shape[0], "y": original_shape[1], "mode": "fixed"}
111+
model, in_path, out_path, padding={"x": original_shape[0], "y": original_shape[1], "mode": "fixed"}
88112
)
89113
check_result()
90114

91115
# test with automated padding
92-
predict_image(any_model, in_path, out_path, padding=True)
116+
predict_image(model, in_path, out_path, padding=True)
93117
check_result()
94118

95119

96-
def test_predict_image_with_tiling(unet2d_nuclei_broad_model, tmp_path):
120+
# prediction with padding with the parameters above may not be suited for any model
121+
# so we only run it for the pytorch unet2d here
122+
def test_predict_image_with_padding(unet2d_fixed_shape_or_not, tmp_path):
123+
_test_predict_with_padding(unet2d_fixed_shape_or_not, tmp_path)
124+
125+
126+
def test_predict_image_with_padding_channel_last(stardist, tmp_path):
127+
_test_predict_with_padding(stardist, tmp_path)
128+
129+
130+
def _test_predict_image_with_tiling(model, tmp_path):
97131
from bioimageio.core.prediction import predict_image
98132

99-
spec = load_resource_description(unet2d_nuclei_broad_model)
133+
spec = load_resource_description(model)
100134
assert isinstance(spec, Model)
101135
inputs = spec.test_inputs
102136
assert len(inputs) == 1
@@ -114,14 +148,24 @@ def check_result():
114148

115149
# with tiling config
116150
tiling = {"halo": {"x": 32, "y": 32}, "tile": {"x": 256, "y": 256}}
117-
predict_image(unet2d_nuclei_broad_model, inputs, [out_path], tiling=tiling)
151+
predict_image(model, inputs, [out_path], tiling=tiling)
118152
check_result()
119153

120154
# with tiling determined from spec
121-
predict_image(unet2d_nuclei_broad_model, inputs, [out_path], tiling=True)
155+
predict_image(model, inputs, [out_path], tiling=True)
122156
check_result()
123157

124158

159+
# prediction with tiling with the parameters above may not be suited for any model
160+
# so we only run it for the pytorch unet2d here
161+
def test_predict_image_with_tiling(unet2d_nuclei_broad_model, tmp_path):
162+
_test_predict_image_with_tiling(unet2d_nuclei_broad_model, tmp_path)
163+
164+
165+
def test_predict_image_with_tiling_channel_last(stardist, tmp_path):
166+
_test_predict_image_with_tiling(stardist, tmp_path)
167+
168+
125169
def test_predict_images(unet2d_nuclei_broad_model, tmp_path):
126170
from bioimageio.core.prediction import predict_images
127171

0 commit comments

Comments
 (0)