11import collections
22import os
3- from copy import deepcopy
43from itertools import product
54from pathlib import Path
65from typing import Dict , Iterator , List , NamedTuple , Optional , OrderedDict , Sequence , Tuple , Union
76
8- import imageio
97import numpy as np
108import xarray as xr
119from tqdm import tqdm
1210
11+ from bioimageio .core import image_helper
1312from bioimageio .core import load_resource_description
1413from bioimageio .core .prediction_pipeline import PredictionPipeline , create_prediction_pipeline
15- from bioimageio .core .resource_io .nodes import ImplicitOutputShape , InputTensor , Model , ResourceDescription , OutputTensor
14+ from bioimageio .core .resource_io .nodes import ImplicitOutputShape , Model , ResourceDescription
1615from bioimageio .spec .shared import raw_nodes
1716from bioimageio .spec .shared .raw_nodes import ResourceDescription as RawResourceDescription
1817
1918
20- #
21- # utility functions for prediction
22- #
23- def _require_axes (im , axes ):
24- is_volume = "z" in axes
25- # we assume images / volumes are loaded as one of
26- # yx, yxc, zyxc
27- if im .ndim == 2 :
28- im_axes = ("y" , "x" )
29- elif im .ndim == 3 :
30- im_axes = ("z" , "y" , "x" ) if is_volume else ("y" , "x" , "c" )
31- elif im .ndim == 4 :
32- raise NotImplementedError
33- else : # ndim >= 5 not implemented
34- raise RuntimeError
35-
36- # add singleton channel dimension if not present
37- if "c" not in im_axes :
38- im = im [..., None ]
39- im_axes = im_axes + ("c" ,)
40-
41- # add singleton batch dim
42- im = im [None ]
43- im_axes = ("b" ,) + im_axes
44-
45- # permute the axes correctly
46- assert set (axes ) == set (im_axes )
47- axes_permutation = tuple (im_axes .index (ax ) for ax in axes )
48- im = im .transpose (axes_permutation )
49- return im
50-
51-
52- def _pad (im , axes : Sequence [str ], padding , pad_right = True ) -> Tuple [np .ndarray , Dict [str , slice ]]:
53- assert im .ndim == len (axes ), f"{ im .ndim } , { len (axes )} "
54-
55- padding_ = deepcopy (padding )
56- mode = padding_ .pop ("mode" , "dynamic" )
57- assert mode in ("dynamic" , "fixed" )
58-
59- is_volume = "z" in axes
60- if is_volume :
61- assert len (padding_ ) == 3
62- else :
63- assert len (padding_ ) == 2
64-
65- if isinstance (pad_right , bool ):
66- pad_right = len (axes ) * [pad_right ]
67-
68- pad_width = []
69- crop = {}
70- for ax , dlen , pr in zip (axes , im .shape , pad_right ):
71-
72- if ax in "zyx" :
73- pad_to = padding_ [ax ]
74-
75- if mode == "dynamic" :
76- r = dlen % pad_to
77- pwidth = 0 if r == 0 else (pad_to - r )
78- else :
79- if pad_to < dlen :
80- msg = f"Padding for axis { ax } failed; pad shape { pad_to } is smaller than the image shape { dlen } ."
81- raise RuntimeError (msg )
82- pwidth = pad_to - dlen
83-
84- pad_width .append ([0 , pwidth ] if pr else [pwidth , 0 ])
85- crop [ax ] = slice (0 , dlen ) if pr else slice (pwidth , None )
86- else :
87- pad_width .append ([0 , 0 ])
88- crop [ax ] = slice (None )
89-
90- im = np .pad (im , pad_width , mode = "symmetric" )
91- return im , crop
92-
93-
94- def _load_image (in_path , axes : Sequence [str ]) -> xr .DataArray :
95- ext = os .path .splitext (in_path )[1 ]
96- if ext == ".npy" :
97- im = np .load (in_path )
98- else :
99- is_volume = "z" in axes
100- im = imageio .volread (in_path ) if is_volume else imageio .imread (in_path )
101- im = _require_axes (im , axes )
102- return xr .DataArray (im , dims = axes )
103-
104-
105- def _load_tensors (sources , tensor_specs : List [Union [InputTensor , OutputTensor ]]) -> List [xr .DataArray ]:
106- return [_load_image (s , sspec .axes ) for s , sspec in zip (sources , tensor_specs )]
107-
108-
109- def _to_channel_last (image ):
110- chan_id = image .dims .index ("c" )
111- if chan_id != image .ndim - 1 :
112- target_axes = tuple (ax for ax in image .dims if ax != "c" ) + ("c" ,)
113- image = image .transpose (* target_axes )
114- return image
115-
116-
117- def _save_image (out_path , image ):
118- ext = os .path .splitext (out_path )[1 ]
119- if ext == ".npy" :
120- np .save (out_path , image )
121- else :
122- is_volume = "z" in image .dims
123-
124- # squeeze batch or channel axes if they are singletons
125- squeeze = {ax : 0 if (ax in "bc" and sh == 1 ) else slice (None ) for ax , sh in zip (image .dims , image .shape )}
126- image = image [squeeze ]
127-
128- if "b" in image .dims :
129- raise RuntimeError (f"Cannot save prediction with batchsize > 1 as { ext } -file" )
130- if "c" in image .dims : # image formats need channel last
131- image = _to_channel_last (image )
132-
133- save_function = imageio .volsave if is_volume else imageio .imsave
134- # most image formats only support channel dimensions of 1, 3 or 4;
135- # if not we need to save the channels separately
136- ndim = 3 if is_volume else 2
137- save_as_single_image = image .ndim == ndim or (image .shape [- 1 ] in (3 , 4 ))
138-
139- if save_as_single_image :
140- save_function (out_path , image )
141- else :
142- out_prefix , ext = os .path .splitext (out_path )
143- for c in range (image .shape [- 1 ]):
144- chan_out_path = f"{ out_prefix } -c{ c } { ext } "
145- save_function (chan_out_path , image [..., c ])
146-
147-
14819def _apply_crop (data , crop ):
14920 crop = tuple (crop [ax ] for ax in data .dims )
15021 return data [crop ]
@@ -345,7 +216,7 @@ def predict_with_padding(
345216 assert len (padding ) == len (prediction_pipeline .input_specs )
346217 inputs , crops = zip (
347218 * [
348- _pad (inp , spec .axes , p , pad_right = pad_right )
219+ image_helper . pad (inp , spec .axes , p , pad_right = pad_right )
349220 for inp , spec , p in zip (inputs , prediction_pipeline .input_specs , padding )
350221 ]
351222 )
@@ -508,7 +379,7 @@ def _predict_sample(prediction_pipeline, inputs, outputs, padding, tiling):
508379 if padding and tiling :
509380 raise ValueError ("Only one of padding or tiling is supported" )
510381
511- input_data = _load_tensors (inputs , prediction_pipeline .input_specs )
382+ input_data = image_helper . load_tensors (inputs , prediction_pipeline .input_specs )
512383 if padding is not None :
513384 result = predict_with_padding (prediction_pipeline , input_data , padding )
514385 elif tiling is not None :
@@ -519,7 +390,7 @@ def _predict_sample(prediction_pipeline, inputs, outputs, padding, tiling):
519390 assert isinstance (result , list )
520391 assert len (result ) == len (outputs )
521392 for res , out in zip (result , outputs ):
522- _save_image (out , res )
393+ image_helper . save_image (out , res )
523394
524395
525396def predict_image (
0 commit comments