55from shutil import copyfile
66from typing import Any , Dict , List , Optional , Tuple , Union
77
8+ import imageio
89import numpy as np
910import requests
1011
@@ -172,14 +173,6 @@ def _get_data_range(data_range, dtype):
172173 return data_range
173174
174175
175- def _get_axes (axes , ndim ):
176- if axes is None :
177- assert ndim in (2 , 4 , 5 )
178- default_axes = {2 : "bc" , 4 : "bcyx" , 5 : "bczyx" }
179- axes = default_axes [ndim ]
180- return axes
181-
182-
183176def _get_input_tensor (path , name , step , min_shape , data_range , axes , preprocessing ):
184177 test_in = np .load (path )
185178 shape = test_in .shape
@@ -189,9 +182,7 @@ def _get_input_tensor(path, name, step, min_shape, data_range, axes, preprocessi
189182 else :
190183 shape_description = {"min" : shape if min_shape is None else min_shape , "step" : step }
191184
192- axes = _get_axes (axes , test_in .ndim )
193185 data_range = _get_data_range (data_range , test_in .dtype )
194-
195186 kwargs = {}
196187 if preprocessing is not None :
197188 kwargs ["preprocessing" ] = [{"name" : k , "kwargs" : v } for k , v in preprocessing .items ()]
@@ -219,9 +210,7 @@ def _get_output_tensor(path, name, reference_tensor, scale, offset, axes, data_r
219210 assert offset is not None
220211 shape_description = {"reference_tensor" : reference_tensor , "scale" : scale , "offset" : offset }
221212
222- axes = _get_axes (axes , test_out .ndim )
223213 data_range = _get_data_range (data_range , test_out .dtype )
224-
225214 kwargs = {}
226215 if postprocessing is not None :
227216 kwargs ["postprocessing" ] = [{"name" : k , "kwargs" : v } for k , v in postprocessing .items ()]
@@ -417,6 +406,106 @@ def write_im(path, im, axes):
417406 return [Path (p .name ) for p in sample_in_paths ], [Path (p .name ) for p in sample_out_paths ]
418407
419408
409+ # create better cover images for 3d data and non-image outputs
410+ def _generate_covers (in_path , out_path , input_axes , output_axes , root ):
411+
412+ def normalize (data , axis , eps = 1e-7 ):
413+ data = data .astype ('float32' )
414+ data -= data .min (axis = axis , keepdims = True )
415+ data /= (data .max (axis = axis , keepdims = True ) + eps )
416+ return data
417+
418+ def to_image (data , data_axes ):
419+ assert data .ndim in (4 , 5 )
420+
421+ # transpose the data to "bczyx" / "bcyx" order
422+ axes = "bczyx" if data .ndim == 5 else "bcyx"
423+ assert set (data_axes ) == set (axes )
424+ if axes != data_axes :
425+ ax_permutation = tuple (data_axes .index (ax ) for ax in axes )
426+ data = data .transpose (ax_permutation )
427+
428+ # select single image with channels from the data
429+ if data .ndim == 5 :
430+ z0 = data .shape [2 ] // 2
431+ data = data [0 , :, z0 ]
432+ else :
433+ data = data [0 , :]
434+
435+ # normalize the data and map to 8 bit
436+ data = normalize (data , axis = (1 , 2 ))
437+ data = (data * 255 ).astype ("uint8" )
438+ return data
439+
440+ cover_path = os .path .join (root , "cover.png" )
441+ input_ , output = np .load (in_path ), np .load (out_path )
442+
443+ input_ = to_image (input_ , input_axes )
444+ # this is not image data so we only save the input image
445+ if output .ndim < 4 :
446+ imageio .imwrite (cover_path , input_ .transpose ((1 , 2 , 0 )))
447+ return [_ensure_local (cover_path , root )]
448+ output = to_image (output , output_axes )
449+
450+ chan_in = input_ .shape [0 ]
451+ # make sure the input is rgb
452+ if chan_in == 1 : # single channel -> repeat it 3 times
453+ input_ = np .repeat (input_ , 3 , axis = 0 )
454+ elif chan_in != 3 : # != 3 channels -> take first channe and repeat it 3 times
455+ input_ = np .repeat (input_ [0 :1 ], 3 , axis = 0 )
456+
457+ im_shape = input_ .shape [1 :]
458+ if im_shape != output .shape [1 :]: # just return the input image if shapes don"t agree
459+ return input_
460+
461+ def diagonal_split (im0 , im1 ):
462+ assert im0 .shape [0 ] == im1 .shape [0 ] == 3
463+ n , m = im_shape
464+ out = np .ones ((3 , n , m ), dtype = "uint8" )
465+ for c in range (3 ):
466+ outc = np .tril (im0 [c ])
467+ mask = outc == 0
468+ outc [mask ] = np .triu (im1 [c ])[mask ]
469+ out [c ] = outc
470+ return out
471+
472+ def grid_im (im0 , im1 ):
473+ ims_per_row = 3
474+ n_chan = im1 .shape [0 ]
475+ n_images = n_chan + 1
476+ n_rows = int (np .ceil (float (n_images ) / ims_per_row ))
477+
478+ n , m = im_shape
479+ x , y = ims_per_row * n , n_rows * m
480+ out = np .zeros ((3 , y , x ))
481+ images = [im0 ] + [np .repeat (im1 [i :i + 1 ], 3 , axis = 0 ) for i in range (n_chan )]
482+
483+ i , j = 0 , 0
484+ for im in images :
485+ x0 , x1 = i * n , (i + 1 ) * n
486+ y0 , y1 = j * m , (j + 1 ) * m
487+ out [:, y0 :y1 , x0 :x1 ] = im
488+
489+ i += 1
490+ if i == ims_per_row :
491+ i = 0
492+ j += 1
493+
494+ return out
495+
496+ chan_out = output .shape [0 ]
497+ if chan_out == 1 : # single prediction channel: create diagonal split
498+ im = diagonal_split (input_ , np .repeat (output , 3 , axis = 0 ))
499+ elif chan_out == 3 : # three prediction channel: create diagonal split with rgb
500+ im = diagonal_split (input_ , output )
501+ else : # otherwise create grid image
502+ im = grid_im (input_ , output )
503+
504+ # to channel last
505+ imageio .imwrite (cover_path , im .transpose ((1 , 2 , 0 )))
506+ return [_ensure_local (cover_path , root )]
507+
508+
420509def _ensure_local (source : Union [Path , URI , str , list ], root : Path ) -> Union [Path , URI , list ]:
421510 """ensure source is local relative path in root"""
422511 if isinstance (source , list ):
@@ -440,17 +529,18 @@ def _ensure_local_or_url(source: Union[Path, URI, str, list], root: Path) -> Uni
440529
441530
442531def build_model (
532+ # model or tensor specific and required
443533 weight_uri : str ,
444534 test_inputs : List [Union [str , Path ]],
445535 test_outputs : List [Union [str , Path ]],
536+ input_axes : List [str ],
537+ output_axes : List [str ],
446538 # general required
447539 name : str ,
448540 description : str ,
449541 authors : List [Dict [str , str ]],
450542 tags : List [Union [str , Path ]],
451- license : str ,
452543 documentation : Union [str , Path ],
453- covers : List [str ],
454544 cite : Dict [str , str ],
455545 output_path : Union [str , Path ],
456546 # model specific optional
@@ -463,19 +553,19 @@ def build_model(
463553 input_names : Optional [List [str ]] = None ,
464554 input_step : Optional [List [List [int ]]] = None ,
465555 input_min_shape : Optional [List [List [int ]]] = None ,
466- input_axes : Optional [List [str ]] = None ,
467556 input_data_range : Optional [List [List [Union [int , str ]]]] = None ,
468557 output_names : Optional [List [str ]] = None ,
469558 output_reference : Optional [List [str ]] = None ,
470559 output_scale : Optional [List [List [int ]]] = None ,
471560 output_offset : Optional [List [List [int ]]] = None ,
472- output_axes : Optional [List [str ]] = None ,
473561 output_data_range : Optional [List [List [Union [int , str ]]]] = None ,
474562 halo : Optional [List [List [int ]]] = None ,
475563 preprocessing : Optional [List [Dict [str , Dict [str , Union [int , float , str ]]]]] = None ,
476564 postprocessing : Optional [List [Dict [str , Dict [str , Union [int , float , str ]]]]] = None ,
477565 pixel_sizes : Optional [List [Dict [str , float ]]] = None ,
478566 # general optional
567+ license : Optional [str ] = None ,
568+ covers : Optional [List [str ]] = None ,
479569 git_repo : Optional [str ] = None ,
480570 attachments : Optional [Dict [str , Union [str , List [str ]]]] = None ,
481571 packaged_by : Optional [List [str ]] = None ,
@@ -501,13 +591,14 @@ def build_model(
501591 weight_uri="test_weights.pt",
502592 test_inputs=["./test_inputs"],
503593 test_outputs=["./test_outputs"],
594+ input_axes=["bcyx"],
595+ output_axes=["bcyx"],
504596 name="my-model",
505597 description="My very fancy model.",
506598 authors=[{"name": "John Doe", "affiliation": "My Institute"}],
507599 tags=["segmentation", "light sheet data"],
508- license="CC-BY",
600+ license="CC-BY-4.0 ",
509601 documentation="./documentation.md",
510- covers=["./my_cover.png"],
511602 cite={"Architecture": "https://my_architecture.com"},
512603 output_path="my-model.zip"
513604 )
@@ -517,13 +608,13 @@ def build_model(
517608 weight_uri: the url or relative local file path to the weight file for this model.
518609 test_inputs: list of test input files stored in numpy format.
519610 test_outputs: list of test outputs corresponding to test_inputs, stored in numpy format.
611+ input_axes: axis names of the input tensors.
612+ output_axes: axiss names of the output tensors.
520613 name: name of this model.
521614 description: short description of this model.
522615 authors: the authors of this model.
523616 tags: list of tags for this model.
524- license: the license for this model.
525617 documentation: relative file path to markdown documentation for this model.
526- covers: list of relative file paths for cover images.
527618 cite: citations for this model.
528619 output_path: where to save the zipped model package.
529620 source: the file with the source code for the model architecture and the corresponding class.
@@ -534,19 +625,20 @@ def build_model(
534625 input_names: names of the input tensors.
535626 input_step: minimal valid increase of the input tensor shape.
536627 input_min_shape: minimal input tensor shape.
537- input_axes: axes names for the input tensor.
538628 input_data_range: valid data range for the input tensor.
539629 output_names: names of the output tensors.
540630 output_reference: name of the input reference tensor used to cimpute the output tensor shape.
541631 output_scale: multiplicative factor to compute the output tensor shape.
542632 output_offset: additive term to compute the output tensor shape.
543- output_axes: axes names of the output tensor.
544633 output_data_range: valid data range for the output tensor.
545634 halo: halo to be cropped from the output tensor.
546635 preprocessing: list of preprocessing operations for the input.
547636 postprocessing: list of postprocessing operations for the output.
548637 pixel_sizes: the pixel sizes for the input tensors, only for spatial axes.
549638 This information is currently only used by deepimagej, but will be added to the spec soon.
639+ license: the license for this model. By default CC-BY-4.0 will be set as license.
640+ covers: list of file paths for cover images.
641+ By default a cover will be generated from the input and output data.
550642 git_repo: reference git repository for this model.
551643 attachments: list of additional files to package with the model.
552644 packaged_by: list of authors that have packaged this model.
@@ -582,7 +674,6 @@ def build_model(
582674
583675 input_step = n_inputs * [None ] if input_step is None else input_step
584676 input_min_shape = n_inputs * [None ] if input_min_shape is None else input_min_shape
585- input_axes = n_inputs * [None ] if input_axes is None else input_axes
586677 input_data_range = n_inputs * [None ] if input_data_range is None else input_data_range
587678 preprocessing = n_inputs * [None ] if preprocessing is None else preprocessing
588679
@@ -602,7 +693,6 @@ def build_model(
602693 output_reference = n_outputs * [None ] if output_reference is None else output_reference
603694 output_scale = n_outputs * [None ] if output_scale is None else output_scale
604695 output_offset = n_outputs * [None ] if output_offset is None else output_offset
605- output_axes = n_outputs * [None ] if output_axes is None else output_axes
606696 output_data_range = n_outputs * [None ] if output_data_range is None else output_data_range
607697 postprocessing = n_outputs * [None ] if postprocessing is None else postprocessing
608698 halo = n_outputs * [None ] if halo is None else halo
@@ -641,7 +731,12 @@ def build_model(
641731 authors = _build_authors (authors )
642732 cite = _build_cite (cite )
643733 documentation = _ensure_local (documentation , root )
644- covers = _ensure_local (covers , root )
734+ if covers is None :
735+ covers = _generate_covers (root / test_inputs [0 ], root / test_outputs [0 ], input_axes [0 ], output_axes [0 ], root )
736+ else :
737+ covers = _ensure_local (covers , root )
738+ if license is None :
739+ license = "CC-BY-4.0"
645740
646741 # parse the weights
647742 weights , tmp_archtecture = _get_weights (
0 commit comments