Skip to content

Commit f804cb3

Browse files
Require passing onnx or tf version in build_model when using one of these weight formats
1 parent a53daca commit f804cb3

File tree

2 files changed

+61
-11
lines changed

2 files changed

+61
-11
lines changed

bioimageio/core/build_spec/build_model.py

Lines changed: 54 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,16 @@ def _get_pytorch_state_dict_weight_kwargs(architecture, model_kwargs, root):
7878
return weight_kwargs, tmp_archtecture
7979

8080

81-
def _get_weights(original_weight_source, weight_type, root, architecture=None, model_kwargs=None, **kwargs):
81+
def _get_weights(
82+
original_weight_source,
83+
weight_type,
84+
root,
85+
architecture=None,
86+
model_kwargs=None,
87+
tensorflow_version=None,
88+
opset_version=None,
89+
**kwargs
90+
):
8291
weight_path = resolve_source(original_weight_source, root)
8392
if weight_type is None:
8493
weight_type = _infer_weight_type(weight_path)
@@ -98,8 +107,10 @@ def _get_weights(original_weight_source, weight_type, root, architecture=None, m
98107
)
99108

100109
elif weight_type == "onnx":
110+
if opset_version is None:
111+
raise ValueError("opset_version needs to be passed for building an onnx model")
101112
weights = model_spec.raw_nodes.OnnxWeightsEntry(
102-
source=weight_source, sha256=weight_hash, opset_version=kwargs.get("opset_version", 12), **attachments
113+
source=weight_source, sha256=weight_hash, opset_version=opset_version, **attachments
103114
)
104115

105116
elif weight_type == "pytorch_script":
@@ -108,26 +119,32 @@ def _get_weights(original_weight_source, weight_type, root, architecture=None, m
108119
)
109120

110121
elif weight_type == "keras_hdf5":
122+
if tensorflow_version is None:
123+
raise ValueError("tensorflow_version needs to be passed for building a keras model")
111124
weights = model_spec.raw_nodes.KerasHdf5WeightsEntry(
112125
source=weight_source,
113126
sha256=weight_hash,
114-
tensorflow_version=kwargs.get("tensorflow_version", "1.15"),
127+
tensorflow_version=tensorflow_version,
115128
**attachments,
116129
)
117130

118131
elif weight_type == "tensorflow_saved_model_bundle":
132+
if tensorflow_version is None:
133+
raise ValueError("tensorflow_version needs to be passed for building a tensorflow model")
119134
weights = model_spec.raw_nodes.TensorflowSavedModelBundleWeightsEntry(
120135
source=weight_source,
121136
sha256=weight_hash,
122-
tensorflow_version=kwargs.get("tensorflow_version", "1.15"),
137+
tensorflow_version=tensorflow_version,
123138
**attachments,
124139
)
125140

126141
elif weight_type == "tensorflow_js":
142+
if tensorflow_version is None:
143+
raise ValueError("tensorflow_version needs to be passed for building a tensorflow_js model")
127144
weights = model_spec.raw_nodes.TensorflowJsWeightsEntry(
128145
source=weight_source,
129146
sha256=weight_hash,
130-
tensorflow_version=kwargs.get("tensorflow_version", "1.15"),
147+
tensorflow_version=tensorflow_version,
131148
**attachments,
132149
)
133150

@@ -471,6 +488,8 @@ def build_model(
471488
links: Optional[List[str]] = None,
472489
root: Optional[Union[Path, str]] = None,
473490
add_deepimagej_config: bool = False,
491+
tensorflow_version: Optional[str] = None,
492+
opset_version: Optional[int] = None,
474493
**weight_kwargs,
475494
):
476495
"""Create a zipped bioimage.io model.
@@ -539,7 +558,10 @@ def build_model(
539558
dependencies: relative path to file with dependencies for this model.
540559
root: optional root path for relative paths. This can be helpful when building a spec from another model spec.
541560
add_deepimagej_config: add the deepimagej config to the model.
542-
weight_kwargs: keyword arguments for this weight type, e.g. "tensorflow_version".
561+
tensorflow_version: the tensorflow version used for training the model.
562+
Needs to be passed for tensorflow or keras models.
563+
opset_version: the opset version used in this model. Needs to be passed for onnx models.
564+
weight_kwargs: additional keyword arguments for this weight type.
543565
"""
544566
if root is None:
545567
root = "."
@@ -624,7 +646,16 @@ def build_model(
624646
covers = _ensure_local(covers, root)
625647

626648
# parse the weights
627-
weights, tmp_archtecture = _get_weights(weight_uri, weight_type, root, architecture, model_kwargs, **weight_kwargs)
649+
weights, tmp_archtecture = _get_weights(
650+
weight_uri,
651+
weight_type,
652+
root,
653+
architecture,
654+
model_kwargs,
655+
tensorflow_version=tensorflow_version,
656+
opset_version=opset_version,
657+
**weight_kwargs,
658+
)
628659

629660
# validate the sample inputs and outputs (if given)
630661
if sample_inputs is not None:
@@ -732,11 +763,25 @@ def add_weights(
732763
weight_uri: Union[str, Path],
733764
weight_type: Optional[str] = None,
734765
output_path: Optional[Union[str, Path]] = None,
766+
architecture: Optional[str] = None,
767+
model_kwargs: Optional[Dict[str, Union[int, float, str]]] = None,
768+
tensorflow_version: Optional[str] = None,
769+
opset_version: Optional[str] = None,
735770
**weight_kwargs,
736771
):
737-
"""Add weight entry to bioimage.io model."""
772+
"""Add weight entry to bioimage.io model.
773+
"""
738774
# we need to pass the weight path as abs path to avoid confusion with different root directories
739-
new_weights, tmp_arch = _get_weights(Path(weight_uri).absolute(), weight_type, root=Path("."), **weight_kwargs)
775+
new_weights, tmp_arch = _get_weights(
776+
Path(weight_uri).absolute(),
777+
weight_type,
778+
root=Path("."),
779+
architecture=architecture,
780+
model_kwargs=model_kwargs,
781+
tensorflow_version=tensorflow_version,
782+
opset_version=opset_version,
783+
**weight_kwargs,
784+
)
740785
model.weights.update(new_weights)
741786
if output_path is not None:
742787
model_package = export_resource_package(model, output_path=output_path)

tests/build_spec/test_build_spec.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ def _test_build_spec(
1010
out_path,
1111
weight_type,
1212
tensorflow_version=None,
13+
opset_version=None,
1314
use_implicit_output_shape=False,
1415
add_deepimagej_config=False,
1516
):
@@ -73,6 +74,8 @@ def _test_build_spec(
7374
kwargs["kwargs"] = model_kwargs
7475
if tensorflow_version is not None:
7576
kwargs["tensorflow_version"] = tensorflow_version
77+
if opset_version is not None:
78+
kwargs["opset_version"] = opset_version
7679
if use_implicit_output_shape:
7780
kwargs["input_name"] = ["input"]
7881
kwargs["output_reference"] = ["input"]
@@ -110,7 +113,7 @@ def test_build_spec_torchscript(any_torchscript_model, tmp_path):
110113

111114

112115
def test_build_spec_onnx(any_onnx_model, tmp_path):
113-
_test_build_spec(any_onnx_model, tmp_path / "model.zip", "onnx")
116+
_test_build_spec(any_onnx_model, tmp_path / "model.zip", "onnx", opset_version=12)
114117

115118

116119
def test_build_spec_keras(any_keras_model, tmp_path):
@@ -134,4 +137,6 @@ def test_build_spec_deepimagej(unet2d_nuclei_broad_model, tmp_path):
134137

135138

136139
def test_build_spec_deepimagej_keras(unet2d_keras, tmp_path):
137-
_test_build_spec(unet2d_keras, tmp_path / "model.zip", "keras_hdf5", add_deepimagej_config=True)
140+
_test_build_spec(
141+
unet2d_keras, tmp_path / "model.zip", "keras_hdf5", add_deepimagej_config=True, tensorflow_version="1.12"
142+
)

0 commit comments

Comments
 (0)