
Commit 5d02803

Merge pull request #103 from bioimage-io/multiple_tensors
Add support for multiple tensor input/output
2 parents 3e5d14a + d456ff7 commit 5d02803
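With this change the prediction helpers accept and return lists of tensors instead of a single array. A minimal usage sketch of the new multi-tensor predict_image call; the model RDF path and tensor file names below are hypothetical placeholders:

from bioimageio.core.prediction import predict_image

# hypothetical model RDF and tensor file names, for illustration only
model_rdf = "my_model/rdf.yaml"
inputs = ["input0.tif", "input1.tif"]      # one image per model input tensor
outputs = ["output0.tif", "output1.tif"]   # one target path per model output tensor

# padding and tiling are still restricted to single-tensor models (see parse_padding / parse_tiling)
predict_image(model_rdf, inputs, outputs)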

13 files changed (+481, -346 lines)


bioimageio/core/prediction.py (131 additions, 92 deletions)
@@ -1,22 +1,25 @@
+import collections
 import os
 import warnings
 from copy import deepcopy
 from itertools import product
 from pathlib import Path
+from typing import Dict, List, OrderedDict, Sequence, Tuple, Union

 import imageio
 import numpy as np
 import xarray as xr

 from bioimageio.core import load_resource_description
 from bioimageio.core.resource_io.nodes import Model
-from bioimageio.core.prediction_pipeline import create_prediction_pipeline
+from bioimageio.core.prediction_pipeline import PredictionPipeline, create_prediction_pipeline
 from tqdm import tqdm


 #
 # utility functions for prediction
 #
+from bioimageio.core.resource_io.nodes import ImplicitOutputShape, URI


 def require_axes(im, axes):
@@ -48,7 +51,7 @@ def require_axes(im, axes):
     return im


-def pad(im, axes, padding, pad_right=True):
+def pad(im, axes: Sequence[str], padding, pad_right=True) -> Tuple[np.ndarray, Dict[str, slice]]:
     assert im.ndim == len(axes), f"{im.ndim}, {len(axes)}"

     padding_ = deepcopy(padding)
@@ -82,10 +85,6 @@ def pad(im, axes, padding, pad_right=True):

             pad_width.append([0, pwidth] if pr else [pwidth, 0])
             crop[ax] = slice(0, dlen) if pr else slice(pwidth, None)
-
-        elif ax in "zyx" and mode == "fixed":
-            pad_to = padding_[ax]
-
         else:
             pad_width.append([0, 0])
             crop[ax] = slice(None)
@@ -172,16 +171,15 @@ def get_tiling(shape, tile_shape, halo, input_axes):
         outer_tile["c"] = slice(None)

         inner_tile = {
-            ax: slice(pos, min(pos + tsh, sh))
-            for ax, pos, tsh, sh in zip(spatial_axes, positions, tile_shape_, shape_)
+            ax: slice(pos, min(pos + tsh, sh)) for ax, pos, tsh, sh in zip(spatial_axes, positions, tile_shape_, shape_)
         }
         inner_tile["b"] = slice(None)
         inner_tile["c"] = slice(None)

         local_tile = {
             ax: slice(
                 inner_tile[ax].start - outer_tile[ax].start,
-                -(outer_tile[ax].stop - inner_tile[ax].stop) if outer_tile[ax].stop != inner_tile[ax].stop else None
+                -(outer_tile[ax].stop - inner_tile[ax].stop) if outer_tile[ax].stop != inner_tile[ax].stop else None,
             )
             for ax in spatial_axes
         }
@@ -191,11 +189,31 @@ def get_tiling(shape, tile_shape, halo, input_axes):
         yield outer_tile, inner_tile, local_tile


-def predict_with_tiling_impl(prediction_pipeline, input_, output, tile_shape, halo, input_axes):
-    assert input_.ndim == len(input_axes), f"{input_.ndim}, {len(input_axes)}"
+def predict_with_tiling_impl(
+    prediction_pipeline,
+    inputs: List[xr.DataArray],
+    outputs: List[xr.DataArray],
+    tile_shapes: List[dict],
+    halos: List[dict],
+):
+    if len(inputs) > 1:
+        raise NotImplementedError("Tiling with multiple inputs not implemented yet")
+
+    if len(outputs) > 1:
+        raise NotImplementedError("Tiling with multiple outputs not implemented yet")
+
+    assert len(tile_shapes) == len(outputs)
+    assert len(halos) == len(outputs)
+
+    input_ = inputs[0]
+    output = outputs[0]
+    tile_shape = tile_shapes[0]
+    halo = halos[0]

-    input_ = xr.DataArray(input_, dims=input_axes)
-    tiles = get_tiling(input_.shape, tile_shape, halo, input_axes)
+    tiles = get_tiling(shape=input_.shape, tile_shape=tile_shape, halo=halo, input_axes=input_.dims)
+
+    assert all(isinstance(ax, str) for ax in input_.dims)
+    input_axes: Tuple[str, ...] = input_.dims  # noqa

     def load_tile(tile):
         inp = input_[tile]
@@ -211,75 +229,94 @@ def load_tile(tile):
     for outer_tile, inner_tile, local_tile in tiles:
         inp, pad_right = load_tile(outer_tile)
         out = predict_with_padding(prediction_pipeline, inp, padding, pad_right)
+        assert len(out) == 1
+        out = out[0]
         output[inner_tile] = out[local_tile]


 #
 # prediction functions
-# TODO support models with multiple in/outputs
 #


-def predict(prediction_pipeline, inputs):
-    if isinstance(inputs, np.ndarray):
+def predict(prediction_pipeline, inputs) -> List[xr.DataArray]:
+    if not isinstance(inputs, (tuple, list)):
         inputs = [inputs]
-    if len(inputs) > 1:
-        raise NotImplementedError(len(inputs))
-    axes = tuple(prediction_pipeline.input_axes)
-    tagged_data = [xr.DataArray(ipt, dims=axes) for ipt in inputs]
+
+    tagged_data = [xr.DataArray(ipt, dims=axes) for ipt, axes in zip(inputs, prediction_pipeline.input_axes)]
     return prediction_pipeline.forward(*tagged_data)


-def predict_with_padding(prediction_pipeline, inputs, padding, pad_right=True):
-    if isinstance(inputs, (np.ndarray, xr.DataArray)):
+def predict_with_padding(prediction_pipeline, inputs, padding, pad_right=True) -> List[xr.DataArray]:
+    if not isinstance(inputs, (tuple, list)):
         inputs = [inputs]
-    axes = tuple(prediction_pipeline.input_axes)
-    inputs = [pad(inp, axes, padding, pad_right=pad_right) for inp in inputs]
-    inputs, crops = [inp[0] for inp in inputs], [inp[1] for inp in inputs]
+
+    assert len(inputs) == len(prediction_pipeline.input_specs)
+
+    if not isinstance(padding, (tuple, list)):
+        padding = [padding]
+
+    assert len(padding) == len(prediction_pipeline.input_specs)
+    inputs, crops = zip(
+        *[
+            pad(inp, spec.axes, p, pad_right=pad_right)
+            for inp, spec, p in zip(inputs, prediction_pipeline.input_specs, padding)
+        ]
+    )
+
     result = predict(prediction_pipeline, inputs)
-    if isinstance(result, (list, tuple)):
-        result = [apply_crop(res, crop) for res, crop in zip(result, crops)]
-    else:
-        result = apply_crop(result, crops[0])
-    return result
+    return [apply_crop(res, crop) for res, crop in zip(result, crops)]


-def predict_with_tiling(prediction_pipeline, inputs, tiling):
-    if isinstance(inputs, (list, tuple)):
-        if len(inputs) > 1:
-            raise NotImplementedError(len(inputs))
-        input_ = inputs[0]
-    else:
-        input_ = inputs
-    input_axes = tuple(prediction_pipeline.input_axes)
+def predict_with_tiling(prediction_pipeline: PredictionPipeline, inputs, tiling) -> List[xr.DataArray]:
+    if not isinstance(inputs, (list, tuple)):
+        inputs = [inputs]

-    output_axes = tuple(prediction_pipeline.output_axes)
-    # NOTE there could also be models with a fixed output shape, but this is currently
-    # not reflected in prediction_pipeline, need to adapt this here once fixed
-    scale, offset = prediction_pipeline.scale, prediction_pipeline.offset
-    scale, offset = {sc[0]: sc[1] for sc in scale}, {off[0]: off[1] for off in offset}
+    assert len(inputs) == len(prediction_pipeline.input_specs)
+    named_inputs: OrderedDict[str, xr.DataArray] = collections.OrderedDict(
+        **{
+            ipt_spec.name: xr.DataArray(ipt_data, dims=tuple(ipt_spec.axes))
+            for ipt_data, ipt_spec in zip(inputs, prediction_pipeline.input_specs)
+        }
+    )

-    # for now, we only support tiling if the spatial shape doesn't change
-    # supporting this should not be so difficult, we would just need to apply the inverse
-    # to "out_shape = scale * in_shape + 2 * offset" ("in_shape = (out_shape - 2 * offset) / scale")
-    # to 'outer_tile' in 'get_tiling'
-    if any(scale[ax] != 1 for ax in output_axes if ax in "xyz") or\
-       any(offset[ax] != 0 for ax in output_axes if ax in "xyz"):
-        raise NotImplementedError("Tiling with a different output shape is not yet supported")
+    outputs = []
+    for output_spec in prediction_pipeline.output_specs:
+        if isinstance(output_spec.shape, ImplicitOutputShape):
+            scale = dict(zip(output_spec.axes, output_spec.shape.scale))
+            offset = dict(zip(output_spec.axes, output_spec.shape.offset))
+
+            # for now, we only support tiling if the spatial shape doesn't change
+            # supporting this should not be so difficult, we would just need to apply the inverse
+            # to "out_shape = scale * in_shape + 2 * offset" ("in_shape = (out_shape - 2 * offset) / scale")
+            # to 'outer_tile' in 'get_tiling'
+            if any(sc != 1 for ax, sc in scale.items() if ax in "xyz") or any(
+                off != 0 for ax, off in offset.items() if ax in "xyz"
+            ):
+                raise NotImplementedError("Tiling with a different output shape is not yet supported")
+
+            ref_input = named_inputs[output_spec.shape.reference_input]
+            ref_input_shape = dict(zip(ref_input.dims, ref_input.shape))
+            output_shape = tuple(int(scale[ax] * ref_input_shape[ax] + 2 * offset[ax]) for ax in output_spec.axes)
+        else:
+            output_shape = tuple(output_spec.shape)

-    out_shape = tuple(int(scale[ax] * input_.shape[input_axes.index(ax)] + 2 * offset[ax]) for ax in output_axes)
-    # TODO the dtype information is missing from prediction pipeline
-    out_dtype = "float32"
-    output = xr.DataArray(np.zeros(out_shape, dtype=out_dtype), dims=output_axes)
+        outputs.append(xr.DataArray(np.zeros(output_shape, dtype=output_spec.data_type), dims=tuple(output_spec.axes)))

-    halo = tiling["halo"]
-    tile_shape = tiling["tile"]
+    predict_with_tiling_impl(
+        prediction_pipeline,
+        list(named_inputs.values()),
+        outputs,
+        tile_shapes=[tiling["tile"]],  # todo: update tiling for multiple inputs/outputs
+        halos=[tiling["halo"]],
+    )

-    predict_with_tiling_impl(prediction_pipeline, input_, output, tile_shape, halo, input_axes)
-    return output
+    return outputs


 def parse_padding(padding, model):
+    if len(model.inputs) > 1:
+        raise NotImplementedError("Padding for multiple inputs not yet implemented")

     input_spec = model.inputs[0]
     pad_keys = tuple(input_spec.axes) + ("mode",)
@@ -311,6 +348,12 @@ def check_padding(padding):


 def parse_tiling(tiling, model):
+    if len(model.inputs) > 1:
+        raise NotImplementedError("Tiling for multiple inputs not yet implemented")
+
+    if len(model.outputs) > 1:
+        raise NotImplementedError("Tiling for multiple outputs not yet implemented")
+
     input_spec = model.inputs[0]
     output_spec = model.outputs[0]

@@ -347,19 +390,15 @@ def check_tiling(tiling):
     return tiling


-# TODO support models with multiple in/outputs
 def predict_image(model_rdf, inputs, outputs, padding=None, tiling=None, weight_format=None, devices=None):
     """Run prediction for a single set of inputs with a bioimage.io model."""
-    if isinstance(inputs, (str, Path)):
+    if not isinstance(inputs, (tuple, list)):
         inputs = [inputs]
-    if len(inputs) > 1:
-        raise NotImplementedError(len(inputs))
-    if isinstance(outputs, (str, Path)):
+
+    if not isinstance(outputs, (tuple, list)):
         outputs = [outputs]
-    if len(outputs) > 1:
-        raise NotImplementedError(len(outputs))

-    model = load_resource_description(Path(model_rdf))
+    model = load_resource_description(model_rdf)
     assert isinstance(model, Model)
     if len(model.inputs) != len(inputs):
         raise ValueError
@@ -369,53 +408,43 @@ def predict_image(model_rdf, inputs, outputs, padding=None, tiling=None, weight_
     prediction_pipeline = create_prediction_pipeline(
         bioimageio_model=model, weight_format=weight_format, devices=devices
     )
-    axes = tuple(prediction_pipeline.input_axes)

     padding = parse_padding(padding, model)
     tiling = parse_tiling(tiling, model)
     if padding is not None and tiling is not None:
         raise ValueError("Only one of padding or tiling is supported")

-    input_data = [load_image(inp, axes) for inp in inputs]
+    input_data = [load_image(inp, axes) for inp, axes in zip(inputs, prediction_pipeline.input_axes)]
     if padding is not None:
         result = predict_with_padding(prediction_pipeline, input_data, padding)
     elif tiling is not None:
         result = predict_with_tiling(prediction_pipeline, input_data, tiling)
     else:
         result = predict(prediction_pipeline, input_data)

-    if isinstance(result, list):
-        assert len(result) == len(outputs)
-        for res, out in zip(result, outputs):
-            save_image(out, res)
-    else:
-        assert len(outputs) == 1
-        save_image(outputs[0], result)
+    assert isinstance(result, list)
+    assert len(result) == len(outputs)
+    for res, out in zip(result, outputs):
+        save_image(out, res)


 def predict_images(
     model_rdf,
-    inputs,
-    outputs,
+    inputs: Sequence[Union[Tuple[Path, ...], List[Path], Path]],
+    outputs: Sequence[Union[Tuple[Path, ...], List[Path], Path]],
     padding=None,
     tiling=None,
     weight_format=None,
     devices=None,
-    verbose=False
+    verbose=False,
 ):
-    """Predict multiple inputs with a bioimage.io model.
-
-    Only works for models with a single input and output tensor.
-    """
-    model = load_resource_description(Path(model_rdf))
+    """Predict multiple inputs with a bioimage.io model."""
+    model = load_resource_description(model_rdf)
     assert isinstance(model, Model)
-    if len(model.inputs) > 1 or len(model.outputs) > 1:
-        raise ValueError("predict_images only supports models that have a single input/output tensor")

     prediction_pipeline = create_prediction_pipeline(
         bioimageio_model=model, weight_format=weight_format, devices=devices
     )
-    axes = tuple(prediction_pipeline.input_axes)

     padding = parse_padding(padding, model)
     tiling = parse_tiling(tiling, model)
@@ -427,32 +456,42 @@ def predict_images(
         prog = tqdm(prog, total=len(inputs))

     for inp, outp in prog:
-        inp = load_image(inp, axes)
+        if not isinstance(inp, (tuple, list)):
+            inp = [inp]
+
+        if not isinstance(outp, (tuple, list)):
+            outp = [outp]
+
+        inp = [load_image(im, sp.axes) for im, sp in zip(inp, prediction_pipeline.input_specs)]
         if padding is not None:
             res = predict_with_padding(prediction_pipeline, inp, padding)
         elif tiling is not None:
             res = predict_with_tiling(prediction_pipeline, inp, tiling)
         else:
             res = predict(prediction_pipeline, inp)
-        save_image(outp, res)
+
+        assert isinstance(res, list)
+        for out, r in zip(outp, res):
+            save_image(out, r)


-def test_model(model_rdf, weight_format=None, devices=None, decimal=4):
+def test_model(model_rdf: Union[URI, Path, str], weight_format=None, devices=None, decimal=4):
     """Test whether the test output(s) of a model can be reproduced.

     Returns True if the test passes, otherwise returns False and issues a warning.
     """
-    model = load_resource_description(Path(model_rdf))
+    print(model_rdf, Path(model_rdf).exists())
+    model = load_resource_description(model_rdf)
     assert isinstance(model, Model)
     prediction_pipeline = create_prediction_pipeline(
         bioimageio_model=model, devices=devices, weight_format=weight_format
     )
-    inputs = [np.load(in_path) for in_path in model.test_inputs]
+    inputs = [np.load(str(in_path)) for in_path in model.test_inputs]
     results = predict(prediction_pipeline, inputs)
     if isinstance(results, (np.ndarray, xr.DataArray)):
         results = [results]

-    expected = [np.load(out_path) for out_path in model.test_outputs]
+    expected = [np.load(str(out_path)) for out_path in model.test_outputs]
     if len(results) != len(expected):
         warnings.warn(f"Number of outputs and number of expected outputs disagree: {len(results)} != {len(expected)}")
         return False