1+ import collections
12import os
23import warnings
34from copy import deepcopy
45from itertools import product
56from pathlib import Path
7+ from typing import Dict , List , OrderedDict , Sequence , Tuple , Union
68
79import imageio
810import numpy as np
911import xarray as xr
1012
1113from bioimageio .core import load_resource_description
1214from bioimageio .core .resource_io .nodes import Model
13- from bioimageio .core .prediction_pipeline import create_prediction_pipeline
15+ from bioimageio .core .prediction_pipeline import PredictionPipeline , create_prediction_pipeline
1416from tqdm import tqdm
1517
1618
1719#
1820# utility functions for prediction
1921#
22+ from bioimageio .core .resource_io .nodes import ImplicitOutputShape , URI
2023
2124
2225def require_axes (im , axes ):
@@ -48,7 +51,7 @@ def require_axes(im, axes):
4851 return im
4952
5053
51- def pad (im , axes , padding , pad_right = True ):
54+ def pad (im , axes : Sequence [ str ] , padding , pad_right = True ) -> Tuple [ np . ndarray , Dict [ str , slice ]] :
5255 assert im .ndim == len (axes ), f"{ im .ndim } , { len (axes )} "
5356
5457 padding_ = deepcopy (padding )
@@ -82,10 +85,6 @@ def pad(im, axes, padding, pad_right=True):
8285
8386 pad_width .append ([0 , pwidth ] if pr else [pwidth , 0 ])
8487 crop [ax ] = slice (0 , dlen ) if pr else slice (pwidth , None )
85-
86- elif ax in "zyx" and mode == "fixed" :
87- pad_to = padding_ [ax ]
88-
8988 else :
9089 pad_width .append ([0 , 0 ])
9190 crop [ax ] = slice (None )
@@ -172,16 +171,15 @@ def get_tiling(shape, tile_shape, halo, input_axes):
172171 outer_tile ["c" ] = slice (None )
173172
174173 inner_tile = {
175- ax : slice (pos , min (pos + tsh , sh ))
176- for ax , pos , tsh , sh in zip (spatial_axes , positions , tile_shape_ , shape_ )
174+ ax : slice (pos , min (pos + tsh , sh )) for ax , pos , tsh , sh in zip (spatial_axes , positions , tile_shape_ , shape_ )
177175 }
178176 inner_tile ["b" ] = slice (None )
179177 inner_tile ["c" ] = slice (None )
180178
181179 local_tile = {
182180 ax : slice (
183181 inner_tile [ax ].start - outer_tile [ax ].start ,
184- - (outer_tile [ax ].stop - inner_tile [ax ].stop ) if outer_tile [ax ].stop != inner_tile [ax ].stop else None
182+ - (outer_tile [ax ].stop - inner_tile [ax ].stop ) if outer_tile [ax ].stop != inner_tile [ax ].stop else None ,
185183 )
186184 for ax in spatial_axes
187185 }
@@ -191,11 +189,31 @@ def get_tiling(shape, tile_shape, halo, input_axes):
191189 yield outer_tile , inner_tile , local_tile
192190
193191
194- def predict_with_tiling_impl (prediction_pipeline , input_ , output , tile_shape , halo , input_axes ):
195- assert input_ .ndim == len (input_axes ), f"{ input_ .ndim } , { len (input_axes )} "
192+ def predict_with_tiling_impl (
193+ prediction_pipeline ,
194+ inputs : List [xr .DataArray ],
195+ outputs : List [xr .DataArray ],
196+ tile_shapes : List [dict ],
197+ halos : List [dict ],
198+ ):
199+ if len (inputs ) > 1 :
200+ raise NotImplementedError ("Tiling with multiple inputs not implemented yet" )
201+
202+ if len (outputs ) > 1 :
203+ raise NotImplementedError ("Tiling with multiple outputs not implemented yet" )
204+
205+ assert len (tile_shapes ) == len (outputs )
206+ assert len (halos ) == len (outputs )
207+
208+ input_ = inputs [0 ]
209+ output = outputs [0 ]
210+ tile_shape = tile_shapes [0 ]
211+ halo = halos [0 ]
196212
197- input_ = xr .DataArray (input_ , dims = input_axes )
198- tiles = get_tiling (input_ .shape , tile_shape , halo , input_axes )
213+ tiles = get_tiling (shape = input_ .shape , tile_shape = tile_shape , halo = halo , input_axes = input_ .dims )
214+
215+ assert all (isinstance (ax , str ) for ax in input_ .dims )
216+ input_axes : Tuple [str , ...] = input_ .dims # noqa
199217
200218 def load_tile (tile ):
201219 inp = input_ [tile ]
@@ -211,75 +229,94 @@ def load_tile(tile):
211229 for outer_tile , inner_tile , local_tile in tiles :
212230 inp , pad_right = load_tile (outer_tile )
213231 out = predict_with_padding (prediction_pipeline , inp , padding , pad_right )
232+ assert len (out ) == 1
233+ out = out [0 ]
214234 output [inner_tile ] = out [local_tile ]
215235
216236
217237#
218238# prediction functions
219- # TODO support models with multiple in/outputs
220239#
221240
222241
223- def predict (prediction_pipeline , inputs ):
224- if isinstance (inputs , np . ndarray ):
242+ def predict (prediction_pipeline , inputs ) -> List [ xr . DataArray ] :
243+ if not isinstance (inputs , ( tuple , list ) ):
225244 inputs = [inputs ]
226- if len (inputs ) > 1 :
227- raise NotImplementedError (len (inputs ))
228- axes = tuple (prediction_pipeline .input_axes )
229- tagged_data = [xr .DataArray (ipt , dims = axes ) for ipt in inputs ]
245+
246+ tagged_data = [xr .DataArray (ipt , dims = axes ) for ipt , axes in zip (inputs , prediction_pipeline .input_axes )]
230247 return prediction_pipeline .forward (* tagged_data )
231248
232249
233- def predict_with_padding (prediction_pipeline , inputs , padding , pad_right = True ):
234- if isinstance (inputs , (np . ndarray , xr . DataArray )):
250+ def predict_with_padding (prediction_pipeline , inputs , padding , pad_right = True ) -> List [ xr . DataArray ] :
251+ if not isinstance (inputs , (tuple , list )):
235252 inputs = [inputs ]
236- axes = tuple (prediction_pipeline .input_axes )
237- inputs = [pad (inp , axes , padding , pad_right = pad_right ) for inp in inputs ]
238- inputs , crops = [inp [0 ] for inp in inputs ], [inp [1 ] for inp in inputs ]
253+
254+ assert len (inputs ) == len (prediction_pipeline .input_specs )
255+
256+ if not isinstance (padding , (tuple , list )):
257+ padding = [padding ]
258+
259+ assert len (padding ) == len (prediction_pipeline .input_specs )
260+ inputs , crops = zip (
261+ * [
262+ pad (inp , spec .axes , p , pad_right = pad_right )
263+ for inp , spec , p in zip (inputs , prediction_pipeline .input_specs , padding )
264+ ]
265+ )
266+
239267 result = predict (prediction_pipeline , inputs )
240- if isinstance (result , (list , tuple )):
241- result = [apply_crop (res , crop ) for res , crop in zip (result , crops )]
242- else :
243- result = apply_crop (result , crops [0 ])
244- return result
268+ return [apply_crop (res , crop ) for res , crop in zip (result , crops )]
245269
246270
247- def predict_with_tiling (prediction_pipeline , inputs , tiling ):
248- if isinstance (inputs , (list , tuple )):
249- if len (inputs ) > 1 :
250- raise NotImplementedError (len (inputs ))
251- input_ = inputs [0 ]
252- else :
253- input_ = inputs
254- input_axes = tuple (prediction_pipeline .input_axes )
271+ def predict_with_tiling (prediction_pipeline : PredictionPipeline , inputs , tiling ) -> List [xr .DataArray ]:
272+ if not isinstance (inputs , (list , tuple )):
273+ inputs = [inputs ]
255274
256- output_axes = tuple (prediction_pipeline .output_axes )
257- # NOTE there could also be models with a fixed output shape, but this is currently
258- # not reflected in prediction_pipeline, need to adapt this here once fixed
259- scale , offset = prediction_pipeline .scale , prediction_pipeline .offset
260- scale , offset = {sc [0 ]: sc [1 ] for sc in scale }, {off [0 ]: off [1 ] for off in offset }
275+ assert len (inputs ) == len (prediction_pipeline .input_specs )
276+ named_inputs : OrderedDict [str , xr .DataArray ] = collections .OrderedDict (
277+ ** {
278+ ipt_spec .name : xr .DataArray (ipt_data , dims = tuple (ipt_spec .axes ))
279+ for ipt_data , ipt_spec in zip (inputs , prediction_pipeline .input_specs )
280+ }
281+ )
261282
262- # for now, we only support tiling if the spatial shape doesn't change
263- # supporting this should not be so difficult, we would just need to apply the inverse
264- # to "out_shape = scale * in_shape + 2 * offset" ("in_shape = (out_shape - 2 * offset) / scale")
265- # to 'outer_tile' in 'get_tiling'
266- if any (scale [ax ] != 1 for ax in output_axes if ax in "xyz" ) or \
267- any (offset [ax ] != 0 for ax in output_axes if ax in "xyz" ):
268- raise NotImplementedError ("Tiling with a different output shape is not yet supported" )
283+ outputs = []
284+ for output_spec in prediction_pipeline .output_specs :
285+ if isinstance (output_spec .shape , ImplicitOutputShape ):
286+ scale = dict (zip (output_spec .axes , output_spec .shape .scale ))
287+ offset = dict (zip (output_spec .axes , output_spec .shape .offset ))
288+
289+ # for now, we only support tiling if the spatial shape doesn't change
290+ # supporting this should not be so difficult, we would just need to apply the inverse
291+ # to "out_shape = scale * in_shape + 2 * offset" ("in_shape = (out_shape - 2 * offset) / scale")
292+ # to 'outer_tile' in 'get_tiling'
293+ if any (sc != 1 for ax , sc in scale .items () if ax in "xyz" ) or any (
294+ off != 0 for ax , off in offset .items () if ax in "xyz"
295+ ):
296+ raise NotImplementedError ("Tiling with a different output shape is not yet supported" )
297+
298+ ref_input = named_inputs [output_spec .shape .reference_input ]
299+ ref_input_shape = dict (zip (ref_input .dims , ref_input .shape ))
300+ output_shape = tuple (int (scale [ax ] * ref_input_shape [ax ] + 2 * offset [ax ]) for ax in output_spec .axes )
301+ else :
302+ output_shape = tuple (output_spec .shape )
269303
270- out_shape = tuple (int (scale [ax ] * input_ .shape [input_axes .index (ax )] + 2 * offset [ax ]) for ax in output_axes )
271- # TODO the dtype information is missing from prediction pipeline
272- out_dtype = "float32"
273- output = xr .DataArray (np .zeros (out_shape , dtype = out_dtype ), dims = output_axes )
304+ outputs .append (xr .DataArray (np .zeros (output_shape , dtype = output_spec .data_type ), dims = tuple (output_spec .axes )))
274305
275- halo = tiling ["halo" ]
276- tile_shape = tiling ["tile" ]
306+ predict_with_tiling_impl (
307+ prediction_pipeline ,
308+ list (named_inputs .values ()),
309+ outputs ,
310+ tile_shapes = [tiling ["tile" ]], # todo: update tiling for multiple inputs/outputs
311+ halos = [tiling ["halo" ]],
312+ )
277313
278- predict_with_tiling_impl (prediction_pipeline , input_ , output , tile_shape , halo , input_axes )
279- return output
314+ return outputs
280315
281316
282317def parse_padding (padding , model ):
318+ if len (model .inputs ) > 1 :
319+ raise NotImplementedError ("Padding for multiple inputs not yet implemented" )
283320
284321 input_spec = model .inputs [0 ]
285322 pad_keys = tuple (input_spec .axes ) + ("mode" ,)
@@ -311,6 +348,12 @@ def check_padding(padding):
311348
312349
313350def parse_tiling (tiling , model ):
351+ if len (model .inputs ) > 1 :
352+ raise NotImplementedError ("Tiling for multiple inputs not yet implemented" )
353+
354+ if len (model .outputs ) > 1 :
355+ raise NotImplementedError ("Tiling for multiple outputs not yet implemented" )
356+
314357 input_spec = model .inputs [0 ]
315358 output_spec = model .outputs [0 ]
316359
@@ -347,19 +390,15 @@ def check_tiling(tiling):
347390 return tiling
348391
349392
350- # TODO support models with multiple in/outputs
351393def predict_image (model_rdf , inputs , outputs , padding = None , tiling = None , weight_format = None , devices = None ):
352394 """Run prediction for a single set of inputs with a bioimage.io model."""
353- if isinstance (inputs , (str , Path )):
395+ if not isinstance (inputs , (tuple , list )):
354396 inputs = [inputs ]
355- if len (inputs ) > 1 :
356- raise NotImplementedError (len (inputs ))
357- if isinstance (outputs , (str , Path )):
397+
398+ if not isinstance (outputs , (tuple , list )):
358399 outputs = [outputs ]
359- if len (outputs ) > 1 :
360- raise NotImplementedError (len (outputs ))
361400
362- model = load_resource_description (Path ( model_rdf ) )
401+ model = load_resource_description (model_rdf )
363402 assert isinstance (model , Model )
364403 if len (model .inputs ) != len (inputs ):
365404 raise ValueError
@@ -369,53 +408,43 @@ def predict_image(model_rdf, inputs, outputs, padding=None, tiling=None, weight_
369408 prediction_pipeline = create_prediction_pipeline (
370409 bioimageio_model = model , weight_format = weight_format , devices = devices
371410 )
372- axes = tuple (prediction_pipeline .input_axes )
373411
374412 padding = parse_padding (padding , model )
375413 tiling = parse_tiling (tiling , model )
376414 if padding is not None and tiling is not None :
377415 raise ValueError ("Only one of padding or tiling is supported" )
378416
379- input_data = [load_image (inp , axes ) for inp in inputs ]
417+ input_data = [load_image (inp , axes ) for inp , axes in zip ( inputs , prediction_pipeline . input_axes ) ]
380418 if padding is not None :
381419 result = predict_with_padding (prediction_pipeline , input_data , padding )
382420 elif tiling is not None :
383421 result = predict_with_tiling (prediction_pipeline , input_data , tiling )
384422 else :
385423 result = predict (prediction_pipeline , input_data )
386424
387- if isinstance (result , list ):
388- assert len (result ) == len (outputs )
389- for res , out in zip (result , outputs ):
390- save_image (out , res )
391- else :
392- assert len (outputs ) == 1
393- save_image (outputs [0 ], result )
425+ assert isinstance (result , list )
426+ assert len (result ) == len (outputs )
427+ for res , out in zip (result , outputs ):
428+ save_image (out , res )
394429
395430
396431def predict_images (
397432 model_rdf ,
398- inputs ,
399- outputs ,
433+ inputs : Sequence [ Union [ Tuple [ Path , ...], List [ Path ], Path ]] ,
434+ outputs : Sequence [ Union [ Tuple [ Path , ...], List [ Path ], Path ]] ,
400435 padding = None ,
401436 tiling = None ,
402437 weight_format = None ,
403438 devices = None ,
404- verbose = False
439+ verbose = False ,
405440):
406- """Predict multiple inputs with a bioimage.io model.
407-
408- Only works for models with a single input and output tensor.
409- """
410- model = load_resource_description (Path (model_rdf ))
441+ """Predict multiple inputs with a bioimage.io model."""
442+ model = load_resource_description (model_rdf )
411443 assert isinstance (model , Model )
412- if len (model .inputs ) > 1 or len (model .outputs ) > 1 :
413- raise ValueError ("predict_images only supports models that have a single input/output tensor" )
414444
415445 prediction_pipeline = create_prediction_pipeline (
416446 bioimageio_model = model , weight_format = weight_format , devices = devices
417447 )
418- axes = tuple (prediction_pipeline .input_axes )
419448
420449 padding = parse_padding (padding , model )
421450 tiling = parse_tiling (tiling , model )
@@ -427,32 +456,42 @@ def predict_images(
427456 prog = tqdm (prog , total = len (inputs ))
428457
429458 for inp , outp in prog :
430- inp = load_image (inp , axes )
459+ if not isinstance (inp , (tuple , list )):
460+ inp = [inp ]
461+
462+ if not isinstance (outp , (tuple , list )):
463+ outp = [outp ]
464+
465+ inp = [load_image (im , sp .axes ) for im , sp in zip (inp , prediction_pipeline .input_specs )]
431466 if padding is not None :
432467 res = predict_with_padding (prediction_pipeline , inp , padding )
433468 elif tiling is not None :
434469 res = predict_with_tiling (prediction_pipeline , inp , tiling )
435470 else :
436471 res = predict (prediction_pipeline , inp )
437- save_image (outp , res )
472+
473+ assert isinstance (res , list )
474+ for out , r in zip (outp , res ):
475+ save_image (out , r )
438476
439477
440- def test_model (model_rdf , weight_format = None , devices = None , decimal = 4 ):
478+ def test_model (model_rdf : Union [ URI , Path , str ] , weight_format = None , devices = None , decimal = 4 ):
441479 """Test whether the test output(s) of a model can be reproduced.
442480
443481 Returns True if the test passes, otherwise returns False and issues a warning.
444482 """
445- model = load_resource_description (Path (model_rdf ))
483+ print (model_rdf , Path (model_rdf ).exists ())
484+ model = load_resource_description (model_rdf )
446485 assert isinstance (model , Model )
447486 prediction_pipeline = create_prediction_pipeline (
448487 bioimageio_model = model , devices = devices , weight_format = weight_format
449488 )
450- inputs = [np .load (in_path ) for in_path in model .test_inputs ]
489+ inputs = [np .load (str ( in_path ) ) for in_path in model .test_inputs ]
451490 results = predict (prediction_pipeline , inputs )
452491 if isinstance (results , (np .ndarray , xr .DataArray )):
453492 results = [results ]
454493
455- expected = [np .load (out_path ) for out_path in model .test_outputs ]
494+ expected = [np .load (str ( out_path ) ) for out_path in model .test_outputs ]
456495 if len (results ) != len (expected ):
457496 warnings .warn (f"Number of outputs and number of expected outputs disagree: { len (results )} != { len (expected )} " )
458497 return False
0 commit comments