Skip to content

Commit 06ec7d5

Browse files
Merge pull request #176 from bioimage-io/more-tf-features
More tf features
2 parents 1200b5a + 585859d commit 06ec7d5

File tree

13 files changed

+261
-25
lines changed

13 files changed

+261
-25
lines changed

.github/workflows/build.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ jobs:
8080
uses: conda-incubator/setup-miniconda@v2
8181
with:
8282
auto-update-conda: true
83+
mamba-version: "*"
8384
channel-priority: strict
8485
activate-environment: bio-core-tf
8586
environment-file: dev/environment-tf.yaml

bioimageio/core/__main__.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,11 @@
2323
except ImportError:
2424
torch_converter = None
2525

26+
try:
27+
from bioimageio.core.weight_converter import keras as keras_converter
28+
except ImportError:
29+
keras_converter = None
30+
2631

2732
# extend help/version string by core version
2833
help_version_core = f"bioimageio.core {__version__}"
@@ -231,7 +236,8 @@ def convert_torch_weights_to_onnx(
231236
use_tracing: bool = typer.Option(True, help="Whether to use torch.jit tracing or scripting."),
232237
verbose: bool = typer.Option(True, help="Verbosity"),
233238
) -> int:
234-
return torch_converter.convert_weights_to_onnx(model_rdf, output_path, opset_version, use_tracing, verbose)
239+
ret_code = torch_converter.convert_weights_to_onnx(model_rdf, output_path, opset_version, use_tracing, verbose)
240+
sys.exit(ret_code)
235241

236242
convert_torch_weights_to_onnx.__doc__ = torch_converter.convert_weights_to_onnx.__doc__
237243

@@ -249,5 +255,22 @@ def convert_torch_weights_to_torchscript(
249255
convert_torch_weights_to_torchscript.__doc__ = torch_converter.convert_weights_to_pytorch_script.__doc__
250256

251257

258+
if keras_converter is not None:
259+
260+
@app.command()
261+
def convert_keras_weights_to_tensorflow(
262+
model_rdf: Path = typer.Argument(
263+
..., help="Path to the model resource description file (rdf.yaml) or zipped model."
264+
),
265+
output_path: Path = typer.Argument(..., help="Where to save the tensorflow weights."),
266+
) -> int:
267+
ret_code = keras_converter.convert_weights_to_tensorflow_saved_model_bundle(model_rdf, output_path)
268+
sys.exit(ret_code)
269+
270+
convert_keras_weights_to_tensorflow.__doc__ = (
271+
keras_converter.convert_weights_to_tensorflow_saved_model_bundle.__doc__
272+
)
273+
274+
252275
if __name__ == "__main__":
253276
app()

bioimageio/core/build_spec/build_model.py

Lines changed: 52 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,16 @@ def _get_pytorch_state_dict_weight_kwargs(architecture, model_kwargs, root):
7878
return weight_kwargs, tmp_archtecture
7979

8080

81-
def _get_weights(original_weight_source, weight_type, root, architecture=None, model_kwargs=None, **kwargs):
81+
def _get_weights(
82+
original_weight_source,
83+
weight_type,
84+
root,
85+
architecture=None,
86+
model_kwargs=None,
87+
tensorflow_version=None,
88+
opset_version=None,
89+
**kwargs,
90+
):
8291
weight_path = resolve_source(original_weight_source, root)
8392
if weight_type is None:
8493
weight_type = _infer_weight_type(weight_path)
@@ -98,8 +107,10 @@ def _get_weights(original_weight_source, weight_type, root, architecture=None, m
98107
)
99108

100109
elif weight_type == "onnx":
110+
if opset_version is None:
111+
raise ValueError("opset_version needs to be passed for building an onnx model")
101112
weights = model_spec.raw_nodes.OnnxWeightsEntry(
102-
source=weight_source, sha256=weight_hash, opset_version=kwargs.get("opset_version", 12), **attachments
113+
source=weight_source, sha256=weight_hash, opset_version=opset_version, **attachments
103114
)
104115

105116
elif weight_type == "pytorch_script":
@@ -108,26 +119,32 @@ def _get_weights(original_weight_source, weight_type, root, architecture=None, m
108119
)
109120

110121
elif weight_type == "keras_hdf5":
122+
if tensorflow_version is None:
123+
raise ValueError("tensorflow_version needs to be passed for building a keras model")
111124
weights = model_spec.raw_nodes.KerasHdf5WeightsEntry(
112125
source=weight_source,
113126
sha256=weight_hash,
114-
tensorflow_version=kwargs.get("tensorflow_version", "1.15"),
127+
tensorflow_version=tensorflow_version,
115128
**attachments,
116129
)
117130

118131
elif weight_type == "tensorflow_saved_model_bundle":
132+
if tensorflow_version is None:
133+
raise ValueError("tensorflow_version needs to be passed for building a tensorflow model")
119134
weights = model_spec.raw_nodes.TensorflowSavedModelBundleWeightsEntry(
120135
source=weight_source,
121136
sha256=weight_hash,
122-
tensorflow_version=kwargs.get("tensorflow_version", "1.15"),
137+
tensorflow_version=tensorflow_version,
123138
**attachments,
124139
)
125140

126141
elif weight_type == "tensorflow_js":
142+
if tensorflow_version is None:
143+
raise ValueError("tensorflow_version needs to be passed for building a tensorflow_js model")
127144
weights = model_spec.raw_nodes.TensorflowJsWeightsEntry(
128145
source=weight_source,
129146
sha256=weight_hash,
130-
tensorflow_version=kwargs.get("tensorflow_version", "1.15"),
147+
tensorflow_version=tensorflow_version,
131148
**attachments,
132149
)
133150

@@ -471,6 +488,8 @@ def build_model(
471488
links: Optional[List[str]] = None,
472489
root: Optional[Union[Path, str]] = None,
473490
add_deepimagej_config: bool = False,
491+
tensorflow_version: Optional[str] = None,
492+
opset_version: Optional[int] = None,
474493
**weight_kwargs,
475494
):
476495
"""Create a zipped bioimage.io model.
@@ -539,7 +558,10 @@ def build_model(
539558
dependencies: relative path to file with dependencies for this model.
540559
root: optional root path for relative paths. This can be helpful when building a spec from another model spec.
541560
add_deepimagej_config: add the deepimagej config to the model.
542-
weight_kwargs: keyword arguments for this weight type, e.g. "tensorflow_version".
561+
tensorflow_version: the tensorflow version used for training the model.
562+
Needs to be passed for tensorflow or keras models.
563+
opset_version: the opset version used in this model. Needs to be passed for onnx models.
564+
weight_kwargs: additional keyword arguments for this weight type.
543565
"""
544566
if root is None:
545567
root = "."
@@ -624,7 +646,16 @@ def build_model(
624646
covers = _ensure_local(covers, root)
625647

626648
# parse the weights
627-
weights, tmp_archtecture = _get_weights(weight_uri, weight_type, root, architecture, model_kwargs, **weight_kwargs)
649+
weights, tmp_archtecture = _get_weights(
650+
weight_uri,
651+
weight_type,
652+
root,
653+
architecture,
654+
model_kwargs,
655+
tensorflow_version=tensorflow_version,
656+
opset_version=opset_version,
657+
**weight_kwargs,
658+
)
628659

629660
# validate the sample inputs and outputs (if given)
630661
if sample_inputs is not None:
@@ -732,11 +763,24 @@ def add_weights(
732763
weight_uri: Union[str, Path],
733764
weight_type: Optional[str] = None,
734765
output_path: Optional[Union[str, Path]] = None,
766+
architecture: Optional[str] = None,
767+
model_kwargs: Optional[Dict[str, Union[int, float, str]]] = None,
768+
tensorflow_version: Optional[str] = None,
769+
opset_version: Optional[str] = None,
735770
**weight_kwargs,
736771
):
737772
"""Add weight entry to bioimage.io model."""
738773
# we need to pass the weight path as abs path to avoid confusion with different root directories
739-
new_weights, tmp_arch = _get_weights(Path(weight_uri).absolute(), weight_type, root=Path("."), **weight_kwargs)
774+
new_weights, tmp_arch = _get_weights(
775+
Path(weight_uri).absolute(),
776+
weight_type,
777+
root=Path("."),
778+
architecture=architecture,
779+
model_kwargs=model_kwargs,
780+
tensorflow_version=tensorflow_version,
781+
opset_version=opset_version,
782+
**weight_kwargs,
783+
)
740784
model.weights.update(new_weights)
741785
if output_path is not None:
742786
model_package = export_resource_package(model, output_path=output_path)

bioimageio/core/prediction_pipeline/_model_adapters/_keras_model_adapter.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,33 @@
44
# by default, we use the keras integrated with tensorflow
55
try:
66
from tensorflow import keras
7+
import tensorflow as tf
8+
9+
TF_VERSION = tf.__version__
710
except Exception:
811
import keras
12+
13+
TF_VERSION = None
914
import xarray as xr
1015

1116
from ._model_adapter import ModelAdapter
1217

1318

1419
class KerasModelAdapter(ModelAdapter):
1520
def _load(self, *, devices: Optional[Sequence[str]] = None) -> None:
21+
try:
22+
model_tf_version = self.bioimageio_model.weights[self.weight_format].tensorflow_version.version
23+
except AttributeError:
24+
model_tf_version = None
25+
26+
if TF_VERSION is None or model_tf_version is None:
27+
warnings.warn("Could not check tensorflow versions. The prediction results may be wrong.")
28+
elif tuple(model_tf_version[:2]) != tuple(map(int, TF_VERSION.split(".")))[:2]:
29+
warnings.warn(
30+
f"Model tensorflow version {model_tf_version} does not match {TF_VERSION}."
31+
"The prediction results may be wrong"
32+
)
33+
1634
# TODO keras device management
1735
if devices is not None:
1836
warnings.warn(f"Device management is not implemented for keras yet, ignoring the devices {devices}")

bioimageio/core/prediction_pipeline/_model_adapters/_tensorflow_model_adapter.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,24 @@ def _load_model(self, weight_file):
3636

3737
def _load(self, *, devices: Optional[List[str]] = None):
3838
try:
39-
tf_version = self.bioimageio_model.weights[self.weight_format].tensorflow_version.version
39+
model_tf_version = self.bioimageio_model.weights[self.weight_format].tensorflow_version.version
4040
except AttributeError:
41-
tf_version = (1, 14, 0)
42-
tf_major_ver = tf_version[0]
41+
model_tf_version = None
42+
43+
tf_version = tf.__version__
44+
tf_major_and_minor = tuple(map(int, tf_version.split(".")))[:2]
45+
if model_tf_version is None:
46+
warnings.warn(
47+
"The model did not contain metadata about the tensorflow version used for training."
48+
f"Cannot check if it is compatible with tf {tf_version}. The prediction result may be wrong."
49+
)
50+
elif tuple(model_tf_version[:2]) != tf_major_and_minor:
51+
warnings.warn(
52+
f"Model tensorflow version {model_tf_version} does not match {tf_version}."
53+
"The prediction results may be wrong"
54+
)
55+
56+
tf_major_ver = tf_major_and_minor[0]
4357
assert tf_major_ver in (1, 2)
4458
self.use_keras_api = tf_major_ver > 1 or self.weight_format == KerasModelAdapter.weight_format
4559

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .tensorflow import convert_weights_to_tensorflow_saved_model_bundle
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
import os
2+
import shutil
3+
from pathlib import Path
4+
from typing import Union
5+
from zipfile import ZipFile
6+
7+
import bioimageio.spec as spec
8+
from bioimageio.core import load_resource_description
9+
10+
import tensorflow
11+
from tensorflow import saved_model
12+
13+
14+
# adapted from
15+
# https://github.com/deepimagej/pydeepimagej/blob/master/pydeepimagej/yaml/create_config.py#L236
16+
def _convert_tf1(keras_weight_path, output_path, zip_weights):
17+
def build_tf_model():
18+
keras_model = keras.models.load_model(keras_weight_path)
19+
20+
builder = saved_model.builder.SavedModelBuilder(output_path)
21+
signature = saved_model.signature_def_utils.predict_signature_def(
22+
inputs={"input": keras_model.input}, outputs={"output": keras_model.output}
23+
)
24+
25+
signature_def_map = {saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature}
26+
27+
builder.add_meta_graph_and_variables(
28+
keras.backend.get_session(), [saved_model.tag_constants.SERVING], signature_def_map=signature_def_map
29+
)
30+
builder.save()
31+
32+
try:
33+
# try to build the tf model with the keras import from tensorflow
34+
from tensorflow import keras
35+
build_tf_model()
36+
except Exception:
37+
# if the above fails try to export with the standalone keras
38+
import keras
39+
40+
build_tf_model()
41+
42+
if zip_weights:
43+
zipped_model = f"{output_path}.zip"
44+
# zip the weights
45+
file_paths = []
46+
for folder_names, subfolder, filenames in os.walk(os.path.join(output_path)):
47+
for filename in filenames:
48+
# create complete filepath of file in directory
49+
file_paths.append(os.path.join(folder_names, filename))
50+
51+
with ZipFile(zipped_model, "w") as zip_obj:
52+
for f in file_paths:
53+
# Add file to zip
54+
zip_obj.write(f, os.path.relpath(f, output_path))
55+
56+
try:
57+
shutil.rmtree(output_path)
58+
except Exception:
59+
print("TensorFlow bundled model was not removed after compression")
60+
print("TensorFlow model exported to", zipped_model)
61+
else:
62+
print("TensorFlow model exported to", output_path)
63+
return 0
64+
65+
66+
def convert_weights_to_tensorflow_saved_model_bundle(
67+
model_spec: Union[str, Path, spec.model.raw_nodes.Model], output_path: Union[str, Path]
68+
):
69+
"""Convert model weights from format 'keras_hdf5' to 'tensorflow_saved_model_bundle'.
70+
71+
Adapted from
72+
https://github.com/deepimagej/pydeepimagej/blob/5aaf0e71f9b04df591d5ca596f0af633a7e024f5/pydeepimagej/yaml/create_config.py
73+
74+
Args:
75+
model_spec: location of the resource for the input bioimageio model
76+
output_path: where to save the tensorflow weights. This path must not exist yet.
77+
"""
78+
tf_major_ver = int(tensorflow.__version__.split(".")[0])
79+
80+
path_ = Path(output_path)
81+
if path_.suffix == ".zip":
82+
path_ = Path(os.path.splitext(path_)[0])
83+
zip_weights = True
84+
else:
85+
zip_weights = False
86+
87+
if path_.exists():
88+
raise ValueError(f"The ouptut directory at {path_} must not exist.")
89+
90+
model = load_resource_description(model_spec)
91+
assert "keras_hdf5" in model.weights
92+
weight_spec = model.weights["keras_hdf5"]
93+
weight_path = str(weight_spec.source)
94+
95+
if weight_spec.tensorflow_version:
96+
model_tf_major_ver = weight_spec.tensorflow_version.version[0]
97+
if model_tf_major_ver != tf_major_ver:
98+
raise RuntimeError(f"Tensorflow major versions of model {model_tf_major_ver} is not {tf_major_ver}")
99+
100+
if tf_major_ver == 1:
101+
return _convert_tf1(weight_path, str(path_), zip_weights)
102+
else:
103+
raise NotImplementedError

dev/environment-tf.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ dependencies:
1212
- pytest
1313
- python >=3.7,<3.8 # this environment is only available for python 3.7
1414
- xarray
15-
- tensorflow >=1.12,<2.0
15+
- tensorflow >1.14,<2.0
1616
- tifffile
1717
- pip:
1818
- keras==1.2.2

0 commit comments

Comments
 (0)