Skip to content

Commit eb55531

Browse files
Add cover generation and default license to build_model, make input/output_axes mandatory
1 parent 1eb7b86 commit eb55531

File tree

2 files changed

+134
-27
lines changed

2 files changed

+134
-27
lines changed

bioimageio/core/build_spec/build_model.py

Lines changed: 120 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from shutil import copyfile
66
from typing import Any, Dict, List, Optional, Tuple, Union
77

8+
import imageio
89
import numpy as np
910
import requests
1011

@@ -172,14 +173,6 @@ def _get_data_range(data_range, dtype):
172173
return data_range
173174

174175

175-
def _get_axes(axes, ndim):
176-
if axes is None:
177-
assert ndim in (2, 4, 5)
178-
default_axes = {2: "bc", 4: "bcyx", 5: "bczyx"}
179-
axes = default_axes[ndim]
180-
return axes
181-
182-
183176
def _get_input_tensor(path, name, step, min_shape, data_range, axes, preprocessing):
184177
test_in = np.load(path)
185178
shape = test_in.shape
@@ -189,9 +182,7 @@ def _get_input_tensor(path, name, step, min_shape, data_range, axes, preprocessi
189182
else:
190183
shape_description = {"min": shape if min_shape is None else min_shape, "step": step}
191184

192-
axes = _get_axes(axes, test_in.ndim)
193185
data_range = _get_data_range(data_range, test_in.dtype)
194-
195186
kwargs = {}
196187
if preprocessing is not None:
197188
kwargs["preprocessing"] = [{"name": k, "kwargs": v} for k, v in preprocessing.items()]
@@ -219,9 +210,7 @@ def _get_output_tensor(path, name, reference_tensor, scale, offset, axes, data_r
219210
assert offset is not None
220211
shape_description = {"reference_tensor": reference_tensor, "scale": scale, "offset": offset}
221212

222-
axes = _get_axes(axes, test_out.ndim)
223213
data_range = _get_data_range(data_range, test_out.dtype)
224-
225214
kwargs = {}
226215
if postprocessing is not None:
227216
kwargs["postprocessing"] = [{"name": k, "kwargs": v} for k, v in postprocessing.items()]
@@ -417,6 +406,106 @@ def write_im(path, im, axes):
417406
return [Path(p.name) for p in sample_in_paths], [Path(p.name) for p in sample_out_paths]
418407

419408

409+
# create better cover images for 3d data and non-image outputs
410+
def _generate_covers(in_path, out_path, input_axes, output_axes, root):
    """Generate a cover image from one test input / test output pair.

    The image is written to "cover.png" inside ``root``. Depending on the data
    the cover shows:
      - only the input, if the output has fewer than 4 dimensions (i.e. it is
        not image data) or if the spatial shapes of input and output differ,
      - a diagonal split of input and output, if the output has 1 or 3 channels,
      - a grid of the input followed by every output channel otherwise.

    Args:
        in_path: path to the test input, stored in numpy format.
        out_path: path to the test output, stored in numpy format.
        input_axes: axis names of the input; must be a permutation of
            "bcyx" (4d data) or "bczyx" (5d data).
        output_axes: axis names of the output; same constraint as input_axes,
            only checked if the output has 4 or 5 dimensions.
        root: folder the cover image is written to.

    Returns:
        A list with a single entry: the result of ``_ensure_local`` for the
        written cover image.
    """
    cover_path = os.path.join(root, "cover.png")

    def normalize(data, axis, eps=1e-7):
        # min-max normalization along `axis`; eps avoids a division by zero
        # for constant images
        data = data.astype("float32")
        data -= data.min(axis=axis, keepdims=True)
        data /= (data.max(axis=axis, keepdims=True) + eps)
        return data

    def to_image(data, data_axes):
        # reduce a 4d / 5d tensor to a single uint8 image in "cyx" layout
        assert data.ndim in (4, 5), f"expected 4d or 5d data, got {data.ndim}d"

        # transpose the data to "bczyx" / "bcyx" order
        axes = "bczyx" if data.ndim == 5 else "bcyx"
        assert set(data_axes) == set(axes), f"expected a permutation of {axes}, got {data_axes}"
        if axes != data_axes:
            ax_permutation = tuple(data_axes.index(ax) for ax in axes)
            data = data.transpose(ax_permutation)

        # select a single image with channels: first batch element and,
        # for 3d data, the central z-slice
        if data.ndim == 5:
            z0 = data.shape[2] // 2
            data = data[0, :, z0]
        else:
            data = data[0, :]

        # normalize each channel independently and map to 8 bit
        data = normalize(data, axis=(1, 2))
        return (data * 255).astype("uint8")

    def write_cover(image):
        # imageio expects channel-last images
        imageio.imwrite(cover_path, image.transpose((1, 2, 0)))
        return [_ensure_local(cover_path, root)]

    def diagonal_split(im0, im1):
        # lower triangle is taken from im0, the remaining pixels from im1
        assert im0.shape[0] == im1.shape[0] == 3
        out = np.ones(im0.shape, dtype="uint8")
        for c in range(3):
            outc = np.tril(im0[c])
            mask = outc == 0
            outc[mask] = np.triu(im1[c])[mask]
            out[c] = outc
        return out

    def grid_im(im0, im1):
        # tile the rgb input and all (gray-scale) output channels, 3 images per row
        ims_per_row = 3
        n_chan = im1.shape[0]
        n_rows = int(np.ceil(float(n_chan + 1) / ims_per_row))

        # im_shape is (y, x): rows of the grid extend along y, columns along x
        height, width = im0.shape[1:]
        out = np.zeros((3, n_rows * height, ims_per_row * width), dtype="uint8")
        images = [im0] + [np.repeat(im1[c:c + 1], 3, axis=0) for c in range(n_chan)]

        for idx, im in enumerate(images):
            row, col = divmod(idx, ims_per_row)
            y0, x0 = row * height, col * width
            out[:, y0:y0 + height, x0:x0 + width] = im
        return out

    input_, output = np.load(in_path), np.load(out_path)

    input_ = to_image(input_, input_axes)
    # this is not image data so we only save the input image
    if output.ndim < 4:
        return write_cover(input_)
    output = to_image(output, output_axes)

    # make sure the input is rgb
    chan_in = input_.shape[0]
    if chan_in == 1:  # single channel -> repeat it 3 times
        input_ = np.repeat(input_, 3, axis=0)
    elif chan_in != 3:  # != 3 channels -> take first channel and repeat it 3 times
        input_ = np.repeat(input_[0:1], 3, axis=0)

    # only save the input image if the spatial shapes don't agree
    if input_.shape[1:] != output.shape[1:]:
        return write_cover(input_)

    chan_out = output.shape[0]
    if chan_out == 1:  # single prediction channel: create diagonal split
        im = diagonal_split(input_, np.repeat(output, 3, axis=0))
    elif chan_out == 3:  # three prediction channels: create diagonal split with rgb
        im = diagonal_split(input_, output)
    else:  # otherwise create grid image
        im = grid_im(input_, output)

    return write_cover(im)
507+
508+
420509
def _ensure_local(source: Union[Path, URI, str, list], root: Path) -> Union[Path, URI, list]:
421510
"""ensure source is local relative path in root"""
422511
if isinstance(source, list):
@@ -440,17 +529,18 @@ def _ensure_local_or_url(source: Union[Path, URI, str, list], root: Path) -> Uni
440529

441530

442531
def build_model(
532+
# model or tensor specific and required
443533
weight_uri: str,
444534
test_inputs: List[Union[str, Path]],
445535
test_outputs: List[Union[str, Path]],
536+
input_axes: List[str],
537+
output_axes: List[str],
446538
# general required
447539
name: str,
448540
description: str,
449541
authors: List[Dict[str, str]],
450542
tags: List[Union[str, Path]],
451-
license: str,
452543
documentation: Union[str, Path],
453-
covers: List[str],
454544
cite: Dict[str, str],
455545
output_path: Union[str, Path],
456546
# model specific optional
@@ -463,19 +553,19 @@ def build_model(
463553
input_names: Optional[List[str]] = None,
464554
input_step: Optional[List[List[int]]] = None,
465555
input_min_shape: Optional[List[List[int]]] = None,
466-
input_axes: Optional[List[str]] = None,
467556
input_data_range: Optional[List[List[Union[int, str]]]] = None,
468557
output_names: Optional[List[str]] = None,
469558
output_reference: Optional[List[str]] = None,
470559
output_scale: Optional[List[List[int]]] = None,
471560
output_offset: Optional[List[List[int]]] = None,
472-
output_axes: Optional[List[str]] = None,
473561
output_data_range: Optional[List[List[Union[int, str]]]] = None,
474562
halo: Optional[List[List[int]]] = None,
475563
preprocessing: Optional[List[Dict[str, Dict[str, Union[int, float, str]]]]] = None,
476564
postprocessing: Optional[List[Dict[str, Dict[str, Union[int, float, str]]]]] = None,
477565
pixel_sizes: Optional[List[Dict[str, float]]] = None,
478566
# general optional
567+
license: Optional[str] = None,
568+
covers: Optional[List[str]] = None,
479569
git_repo: Optional[str] = None,
480570
attachments: Optional[Dict[str, Union[str, List[str]]]] = None,
481571
packaged_by: Optional[List[str]] = None,
@@ -501,13 +591,14 @@ def build_model(
501591
weight_uri="test_weights.pt",
502592
test_inputs=["./test_inputs"],
503593
test_outputs=["./test_outputs"],
594+
input_axes=["bcyx"],
595+
output_axes=["bcyx"],
504596
name="my-model",
505597
description="My very fancy model.",
506598
authors=[{"name": "John Doe", "affiliation": "My Institute"}],
507599
tags=["segmentation", "light sheet data"],
508-
license="CC-BY",
600+
license="CC-BY-4.0",
509601
documentation="./documentation.md",
510-
covers=["./my_cover.png"],
511602
cite={"Architecture": "https://my_architecture.com"},
512603
output_path="my-model.zip"
513604
)
@@ -517,13 +608,13 @@ def build_model(
517608
weight_uri: the url or relative local file path to the weight file for this model.
518609
test_inputs: list of test input files stored in numpy format.
519610
test_outputs: list of test outputs corresponding to test_inputs, stored in numpy format.
611+
input_axes: axis names of the input tensors.
612+
output_axes: axis names of the output tensors.
520613
name: name of this model.
521614
description: short description of this model.
522615
authors: the authors of this model.
523616
tags: list of tags for this model.
524-
license: the license for this model.
525617
documentation: relative file path to markdown documentation for this model.
526-
covers: list of relative file paths for cover images.
527618
cite: citations for this model.
528619
output_path: where to save the zipped model package.
529620
source: the file with the source code for the model architecture and the corresponding class.
@@ -534,19 +625,20 @@ def build_model(
534625
input_names: names of the input tensors.
535626
input_step: minimal valid increase of the input tensor shape.
536627
input_min_shape: minimal input tensor shape.
537-
input_axes: axes names for the input tensor.
538628
input_data_range: valid data range for the input tensor.
539629
output_names: names of the output tensors.
540630
output_reference: name of the input reference tensor used to compute the output tensor shape.
541631
output_scale: multiplicative factor to compute the output tensor shape.
542632
output_offset: additive term to compute the output tensor shape.
543-
output_axes: axes names of the output tensor.
544633
output_data_range: valid data range for the output tensor.
545634
halo: halo to be cropped from the output tensor.
546635
preprocessing: list of preprocessing operations for the input.
547636
postprocessing: list of postprocessing operations for the output.
548637
pixel_sizes: the pixel sizes for the input tensors, only for spatial axes.
549638
This information is currently only used by deepimagej, but will be added to the spec soon.
639+
license: the license for this model. By default CC-BY-4.0 will be set as license.
640+
covers: list of file paths for cover images.
641+
By default a cover will be generated from the input and output data.
550642
git_repo: reference git repository for this model.
551643
attachments: list of additional files to package with the model.
552644
packaged_by: list of authors that have packaged this model.
@@ -582,7 +674,6 @@ def build_model(
582674

583675
input_step = n_inputs * [None] if input_step is None else input_step
584676
input_min_shape = n_inputs * [None] if input_min_shape is None else input_min_shape
585-
input_axes = n_inputs * [None] if input_axes is None else input_axes
586677
input_data_range = n_inputs * [None] if input_data_range is None else input_data_range
587678
preprocessing = n_inputs * [None] if preprocessing is None else preprocessing
588679

@@ -602,7 +693,6 @@ def build_model(
602693
output_reference = n_outputs * [None] if output_reference is None else output_reference
603694
output_scale = n_outputs * [None] if output_scale is None else output_scale
604695
output_offset = n_outputs * [None] if output_offset is None else output_offset
605-
output_axes = n_outputs * [None] if output_axes is None else output_axes
606696
output_data_range = n_outputs * [None] if output_data_range is None else output_data_range
607697
postprocessing = n_outputs * [None] if postprocessing is None else postprocessing
608698
halo = n_outputs * [None] if halo is None else halo
@@ -641,7 +731,12 @@ def build_model(
641731
authors = _build_authors(authors)
642732
cite = _build_cite(cite)
643733
documentation = _ensure_local(documentation, root)
644-
covers = _ensure_local(covers, root)
734+
if covers is None:
735+
covers = _generate_covers(root / test_inputs[0], root / test_outputs[0], input_axes[0], output_axes[0], root)
736+
else:
737+
covers = _ensure_local(covers, root)
738+
if license is None:
739+
license = "CC-BY-4.0"
645740

646741
# parse the weights
647742
weights, tmp_archtecture = _get_weights(

tests/build_spec/test_build_spec.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ def _test_build_spec(
1313
opset_version=None,
1414
use_implicit_output_shape=False,
1515
add_deepimagej_config=False,
16+
use_original_covers=False,
1617
):
1718
from bioimageio.core.build_spec import build_model
1819

@@ -39,7 +40,9 @@ def _test_build_spec(
3940

4041
dep_file = None if model_spec.dependencies is missing else resolve_source(model_spec.dependencies.file, root)
4142
authors = [{"name": auth.name, "affiliation": auth.affiliation} for auth in model_spec.authors]
42-
covers = resolve_source(model_spec.covers, root)
43+
44+
input_axes = [input_.axes for input_ in model_spec.inputs]
45+
output_axes = [output.axes for output in model_spec.outputs]
4346
preprocessing = [
4447
None if input_.preprocessing == missing else {preproc.name: preproc.kwargs for preproc in input_.preprocessing}
4548
for input_ in model_spec.inputs
@@ -48,6 +51,7 @@ def _test_build_spec(
4851
None if output.postprocessing == missing else {preproc.name: preproc.kwargs for preproc in output.preprocessing}
4952
for output in model_spec.outputs
5053
]
54+
5155
kwargs = dict(
5256
weight_uri=weight_source,
5357
test_inputs=resolve_source(model_spec.test_inputs, root),
@@ -58,11 +62,12 @@ def _test_build_spec(
5862
tags=model_spec.tags,
5963
license=model_spec.license,
6064
documentation=model_spec.documentation,
61-
covers=covers,
6265
dependencies=dep_file,
6366
cite=cite,
6467
root=model_spec.root_path,
6568
weight_type=weight_type_,
69+
input_axes=input_axes,
70+
output_axes=output_axes,
6671
preprocessing=preprocessing,
6772
postprocessing=postprocessing,
6873
output_path=out_path,
@@ -83,6 +88,8 @@ def _test_build_spec(
8388
kwargs["output_offset"] = [[0.0, 0.0, 0.0, 0.0]]
8489
if add_deepimagej_config:
8590
kwargs["pixel_sizes"] = [{"x": 5.0, "y": 5.0}]
91+
if use_original_covers:
92+
kwargs["covers"] = resolve_source(model_spec.covers, root)
8693

8794
build_model(**kwargs)
8895
assert out_path.exists()
@@ -140,3 +147,8 @@ def test_build_spec_deepimagej_keras(unet2d_keras, tmp_path):
140147
_test_build_spec(
141148
unet2d_keras, tmp_path / "model.zip", "keras_hdf5", add_deepimagej_config=True, tensorflow_version="1.12"
142149
)
150+
151+
152+
# test with original covers
153+
def test_build_spec_with_original_covers(unet2d_nuclei_broad_model, tmp_path):
    """Build the torchscript model while passing the covers of the source model spec."""
    package_path = tmp_path / "model.zip"
    _test_build_spec(
        unet2d_nuclei_broad_model,
        package_path,
        "torchscript",
        use_original_covers=True,
    )

0 commit comments

Comments
 (0)