Skip to content

Commit 1200b5a

Browse files
Merge pull request #174 from bioimage-io/v04
Update to v0.4
2 parents e84c8d0 + c7dd960 commit 1200b5a

File tree

7 files changed

+135
-121
lines changed

7 files changed

+135
-121
lines changed

bioimageio/core/build_spec/build_model.py

Lines changed: 60 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from bioimageio.core import export_resource_package, load_raw_resource_description
1414
from bioimageio.core.resource_io.nodes import URI
1515
from bioimageio.core.resource_io.utils import resolve_local_source, resolve_source
16+
from bioimageio.spec.shared.raw_nodes import ImportableSourceFile, ImportableModule
1617

1718
try:
1819
from typing import get_args
@@ -53,67 +54,58 @@ def _infer_weight_type(path):
5354
raise ValueError(f"Could not infer weight type from extension {ext} for weight file {path}")
5455

5556

56-
def _get_weights(original_weight_source, weight_type, source, root, **kwargs):
57+
def _get_pytorch_state_dict_weight_kwargs(architecture, model_kwargs, root):
58+
assert architecture is not None
59+
tmp_archtecture = None
60+
weight_kwargs = {"kwargs": model_kwargs} if model_kwargs else {}
61+
arch = spec.shared.fields.ImportableSource().deserialize(architecture)
62+
if isinstance(arch, ImportableSourceFile):
63+
if os.path.isabs(arch.source_file):
64+
tmp_archtecture = Path("this_model_architecture.py")
65+
66+
copyfile(arch.source_file, root / tmp_archtecture)
67+
arch = ImportableSourceFile(arch.callable_name, tmp_archtecture)
68+
69+
arch_hash = _get_hash(root / arch.source_file)
70+
weight_kwargs["architecture_sha256"] = arch_hash
71+
elif isinstance(arch, ImportableModule):
72+
pass
73+
else:
74+
raise NotImplementedError(arch)
75+
76+
weight_kwargs["architecture"] = arch
77+
78+
return weight_kwargs, tmp_archtecture
79+
80+
81+
def _get_weights(original_weight_source, weight_type, root, architecture=None, model_kwargs=None, **kwargs):
5782
weight_path = resolve_source(original_weight_source, root)
5883
if weight_type is None:
5984
weight_type = _infer_weight_type(weight_path)
6085
weight_hash = _get_hash(weight_path)
6186

62-
tmp_source = None
63-
# if we have a ":" (or deprecated "::") this is a python file with class specified,
64-
# so we can compute the hash for it
65-
if source is not None and ":" in source:
66-
source_file, source_class = source.replace("::", ":").split(":")
67-
68-
# get the source path
69-
source_file = _ensure_local(source_file, root)
70-
source_hash = _get_hash(root / source_file)
71-
72-
# if not relative, create local copy (otherwise this will not work)
73-
if os.path.isabs(source_file):
74-
copyfile(source_file, "this_model_architecture.py")
75-
source = f"this_model_architecture.py:{source_class}"
76-
tmp_source = "this_model_architecture.py"
77-
else:
78-
source = f"{source_file}:{source_class}"
79-
source = spec.shared.fields.ImportableSource().deserialize(source)
80-
else:
81-
source_hash = None
82-
83-
if "weight_attachments" in kwargs:
84-
attachments = {"attachments": ["weight_attachments"]}
85-
else:
86-
attachments = {}
87-
87+
attachments = {"attachments": kwargs["weight_attachments"]} if "weight_attachments" in kwargs else {}
8888
weight_types = model_spec.raw_nodes.WeightsFormat
8989
weight_source = _ensure_local_or_url(original_weight_source, root)
9090

91+
tmp_archtecture = None
9192
if weight_type == "pytorch_state_dict":
92-
# pytorch-state-dict -> we need a source
93-
assert source is not None
93+
# pytorch-state-dict -> we need an architecture definition
94+
weight_kwargs, tmp_file = _get_pytorch_state_dict_weight_kwargs(architecture, model_kwargs, root)
95+
weight_kwargs.update(**attachments)
9496
weights = model_spec.raw_nodes.PytorchStateDictWeightsEntry(
95-
source=weight_source, sha256=weight_hash, **attachments
97+
source=weight_source, sha256=weight_hash, **weight_kwargs
9698
)
97-
language = "python"
98-
framework = "pytorch"
9999

100100
elif weight_type == "onnx":
101101
weights = model_spec.raw_nodes.OnnxWeightsEntry(
102102
source=weight_source, sha256=weight_hash, opset_version=kwargs.get("opset_version", 12), **attachments
103103
)
104-
language = None
105-
framework = None
106104

107105
elif weight_type == "pytorch_script":
108106
weights = model_spec.raw_nodes.PytorchScriptWeightsEntry(
109107
source=weight_source, sha256=weight_hash, **attachments
110108
)
111-
if source is None:
112-
language = None
113-
framework = None
114-
else:
115-
language = "python"
116-
framework = "pytorch"
117109

118110
elif weight_type == "keras_hdf5":
119111
weights = model_spec.raw_nodes.KerasHdf5WeightsEntry(
@@ -122,8 +114,6 @@ def _get_weights(original_weight_source, weight_type, source, root, **kwargs):
122114
tensorflow_version=kwargs.get("tensorflow_version", "1.15"),
123115
**attachments,
124116
)
125-
language = "python"
126-
framework = "tensorflow"
127117

128118
elif weight_type == "tensorflow_saved_model_bundle":
129119
weights = model_spec.raw_nodes.TensorflowSavedModelBundleWeightsEntry(
@@ -132,8 +122,6 @@ def _get_weights(original_weight_source, weight_type, source, root, **kwargs):
132122
tensorflow_version=kwargs.get("tensorflow_version", "1.15"),
133123
**attachments,
134124
)
135-
language = "python"
136-
framework = "tensorflow"
137125

138126
elif weight_type == "tensorflow_js":
139127
weights = model_spec.raw_nodes.TensorflowJsWeightsEntry(
@@ -142,16 +130,14 @@ def _get_weights(original_weight_source, weight_type, source, root, **kwargs):
142130
tensorflow_version=kwargs.get("tensorflow_version", "1.15"),
143131
**attachments,
144132
)
145-
language = None
146-
framework = None
147133

148134
elif weight_type in weight_types:
149135
raise ValueError(f"Weight type {weight_type} is not supported yet in 'build_spec'")
150136
else:
151137
raise ValueError(f"Invalid weight type {weight_type}, expect one of {weight_types}")
152138

153139
weights = {weight_type: weights}
154-
return weights, language, framework, source, source_hash, tmp_source
140+
return weights, tmp_archtecture
155141

156142

157143
def _get_data_range(data_range, dtype):
@@ -286,13 +272,15 @@ def _get_deepimagej_macro(name, kwargs, export_folder):
286272
else:
287273
raise ValueError(f"Macro {name} is not available, must be one of {macro_names}.")
288274

289-
macro = f"{name}.ijm"
290275
url = f"https://raw.githubusercontent.com/deepimagej/imagej-macros/master/bioimage.io/{macro}"
291276

292277
path = os.path.join(export_folder, macro)
293278
# use https://github.com/bioimage-io/core-bioimage-io-python/blob/main/bioimageio/core/resource_io/utils.py#L267
294279
# instead if the implementation is update s.t. an output path is accepted
295280
with requests.get(url, stream=True) as r:
281+
text = r.text
282+
if text.startswith("4"):
283+
raise RuntimeError(f"An error occured when downloading {url}: {r.text}")
296284
with open(path, "w") as f:
297285
f.write(r.text)
298286

@@ -451,18 +439,18 @@ def build_model(
451439
cite: Dict[str, str],
452440
output_path: Union[str, Path],
453441
# model specific optional
454-
source: Optional[str] = None,
442+
architecture: Optional[str] = None,
455443
model_kwargs: Optional[Dict[str, Union[int, float, str]]] = None,
456444
weight_type: Optional[str] = None,
457445
sample_inputs: Optional[List[str]] = None,
458446
sample_outputs: Optional[List[str]] = None,
459447
# tensor specific
460-
input_name: Optional[List[str]] = None,
448+
input_names: Optional[List[str]] = None,
461449
input_step: Optional[List[List[int]]] = None,
462450
input_min_shape: Optional[List[List[int]]] = None,
463451
input_axes: Optional[List[str]] = None,
464452
input_data_range: Optional[List[List[Union[int, str]]]] = None,
465-
output_name: Optional[List[str]] = None,
453+
output_names: Optional[List[str]] = None,
466454
output_reference: Optional[List[str]] = None,
467455
output_scale: Optional[List[List[int]]] = None,
468456
output_offset: Optional[List[List[int]]] = None,
@@ -526,12 +514,12 @@ def build_model(
526514
weight_type: the type of the weights.
527515
sample_inputs: list of sample inputs to demonstrate the model performance.
528516
sample_outputs: list of sample outputs corresponding to sample_inputs.
529-
input_name: name of the input tensor.
517+
input_names: names of the input tensors.
530518
input_step: minimal valid increase of the input tensor shape.
531519
input_min_shape: minimal input tensor shape.
532520
input_axes: axes names for the input tensor.
533521
input_data_range: valid data range for the input tensor.
534-
output_name: name of the output tensor.
522+
output_names: names of the output tensors.
535523
output_reference: name of the input reference tensor used to cimpute the output tensor shape.
536524
output_scale: multiplicative factor to compute the output tensor shape.
537525
output_offset: additive term to compute the output tensor shape.
@@ -567,7 +555,11 @@ def build_model(
567555
test_outputs = _ensure_local_or_url(test_outputs, root)
568556

569557
n_inputs = len(test_inputs)
570-
input_name = n_inputs * [None] if input_name is None else input_name
558+
if input_names is None:
559+
input_names = [f"input{i}" for i in range(n_inputs)]
560+
else:
561+
assert len(input_names) == len(test_inputs)
562+
571563
input_step = n_inputs * [None] if input_step is None else input_step
572564
input_min_shape = n_inputs * [None] if input_min_shape is None else input_min_shape
573565
input_axes = n_inputs * [None] if input_axes is None else input_axes
@@ -577,12 +569,16 @@ def build_model(
577569
inputs = [
578570
_get_input_tensor(root / test_in, name, step, min_shape, data_range, axes, preproc)
579571
for test_in, name, step, min_shape, axes, data_range, preproc in zip(
580-
test_inputs, input_name, input_step, input_min_shape, input_axes, input_data_range, preprocessing
572+
test_inputs, input_names, input_step, input_min_shape, input_axes, input_data_range, preprocessing
581573
)
582574
]
583575

584576
n_outputs = len(test_outputs)
585-
output_name = n_outputs * [None] if output_name is None else output_name
577+
if output_names is None:
578+
output_names = [f"output{i}" for i in range(n_outputs)]
579+
else:
580+
assert len(output_names) == len(test_outputs)
581+
586582
output_reference = n_outputs * [None] if output_reference is None else output_reference
587583
output_scale = n_outputs * [None] if output_scale is None else output_scale
588584
output_offset = n_outputs * [None] if output_offset is None else output_offset
@@ -595,7 +591,7 @@ def build_model(
595591
_get_output_tensor(root / test_out, name, reference, scale, offset, axes, data_range, postproc, hal)
596592
for test_out, name, reference, scale, offset, axes, data_range, postproc, hal in zip(
597593
test_outputs,
598-
output_name,
594+
output_names,
599595
output_reference,
600596
output_scale,
601597
output_offset,
@@ -628,9 +624,7 @@ def build_model(
628624
covers = _ensure_local(covers, root)
629625

630626
# parse the weights
631-
weights, language, framework, source, source_hash, tmp_source = _get_weights(
632-
weight_uri, weight_type, source, root, **weight_kwargs
633-
)
627+
weights, tmp_archtecture = _get_weights(weight_uri, weight_type, root, architecture, model_kwargs, **weight_kwargs)
634628

635629
# validate the sample inputs and outputs (if given)
636630
if sample_inputs is not None:
@@ -692,11 +686,6 @@ def build_model(
692686
"run_mode": run_mode,
693687
"sample_inputs": sample_inputs,
694688
"sample_outputs": sample_outputs,
695-
"framework": framework,
696-
"language": language,
697-
"source": source,
698-
"sha256": source_hash,
699-
"kwargs": model_kwargs,
700689
"links": links,
701690
}
702691
kwargs = {k: v for k, v in optional_kwargs.items() if v is not None}
@@ -731,8 +720,8 @@ def build_model(
731720
except Exception as e:
732721
raise e
733722
finally:
734-
if tmp_source is not None:
735-
os.remove(tmp_source)
723+
if tmp_archtecture is not None:
724+
os.remove(tmp_archtecture)
736725

737726
model = load_raw_resource_description(model_package)
738727
return model
@@ -746,12 +735,12 @@ def add_weights(
746735
**weight_kwargs,
747736
):
748737
"""Add weight entry to bioimage.io model."""
749-
# we need to patss the weight path as abs path to avoid confusion with different root directories
750-
new_weights = _get_weights(Path(weight_uri).absolute(), weight_type, source=None, root=Path("."), **weight_kwargs)[
751-
0
752-
]
738+
# we need to pass the weight path as abs path to avoid confusion with different root directories
739+
new_weights, tmp_arch = _get_weights(Path(weight_uri).absolute(), weight_type, root=Path("."), **weight_kwargs)
753740
model.weights.update(new_weights)
754741
if output_path is not None:
755742
model_package = export_resource_package(model, output_path=output_path)
756743
model = load_raw_resource_description(model_package)
744+
if tmp_arch is not None:
745+
os.remove(tmp_arch)
757746
return model

bioimageio/core/prediction_pipeline/_model_adapters/_keras_model_adapter.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
import warnings
22
from typing import List, Optional, Sequence
33

4-
import keras
4+
# by default, we use the keras integrated with tensorflow
5+
try:
6+
from tensorflow import keras
7+
except Exception:
8+
import keras
59
import xarray as xr
610

711
from ._model_adapter import ModelAdapter

bioimageio/core/prediction_pipeline/_model_adapters/_pytorch_model_adapter.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,10 @@ def _unload(self) -> None:
5252

5353
@staticmethod
5454
def get_nn_instance(model_node: nodes.Model, **kwargs):
55-
assert isinstance(model_node.source, nodes.ImportedSource)
56-
57-
joined_kwargs = {} if model_node.kwargs is missing else dict(model_node.kwargs)
55+
weight_spec = model_node.weights.get("pytorch_state_dict")
56+
assert weight_spec is not None
57+
assert isinstance(weight_spec.architecture, nodes.ImportedSource)
58+
model_kwargs = weight_spec.kwargs
59+
joined_kwargs = {} if model_kwargs is missing else dict(model_kwargs)
5860
joined_kwargs.update(kwargs)
59-
return model_node.source(**joined_kwargs)
61+
return weight_spec.architecture(**joined_kwargs)

0 commit comments

Comments
 (0)