Skip to content

Commit f0d9b8d

Browse files
Add cli for keras weight converter
1 parent 381118c commit f0d9b8d

File tree

3 files changed

+32
-1
lines changed

3 files changed

+32
-1
lines changed

bioimageio/core/__main__.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,11 @@
2323
except ImportError:
2424
torch_converter = None
2525

26+
try:
27+
from bioimageio.core.weight_converter import keras as keras_converter
28+
except ImportError:
29+
keras_converter = None
30+
2631

2732
# extend help/version string by core version
2833
help_version_core = f"bioimageio.core {__version__}"
@@ -231,7 +236,8 @@ def convert_torch_weights_to_onnx(
231236
use_tracing: bool = typer.Option(True, help="Whether to use torch.jit tracing or scripting."),
232237
verbose: bool = typer.Option(True, help="Verbosity"),
233238
) -> int:
234-
return torch_converter.convert_weights_to_onnx(model_rdf, output_path, opset_version, use_tracing, verbose)
239+
ret_code = torch_converter.convert_weights_to_onnx(model_rdf, output_path, opset_version, use_tracing, verbose)
240+
sys.exit(ret_code)
235241

236242
convert_torch_weights_to_onnx.__doc__ = torch_converter.convert_weights_to_onnx.__doc__
237243

@@ -249,5 +255,21 @@ def convert_torch_weights_to_torchscript(
249255
convert_torch_weights_to_torchscript.__doc__ = torch_converter.convert_weights_to_pytorch_script.__doc__
250256

251257

258+
if keras_converter is not None:
259+
260+
@app.command()
261+
def convert_keras_weights_to_tensorflow(
262+
model_rdf: Path = typer.Argument(
263+
..., help="Path to the model resource description file (rdf.yaml) or zipped model."
264+
),
265+
output_path: Path = typer.Argument(..., help="Where to save the tensorflow weights."),
266+
) -> int:
267+
ret_code = keras_converter.convert_weights_to_tensorflow_saved_model_bundle(model_rdf, output_path)
268+
sys.exit(ret_code)
269+
270+
convert_keras_weights_to_tensorflow.__doc__ =\
271+
keras_converter.convert_weights_to_tensorflow_saved_model_bundle.__doc__
272+
273+
252274
if __name__ == "__main__":
253275
app()

bioimageio/core/weight_converter/keras/tensorflow.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
from tensorflow import keras
1212

1313

14+
# adapted from
15+
# https://github.com/deepimagej/pydeepimagej/blob/master/pydeepimagej/yaml/create_config.py#L236
1416
def _convert_tf1(keras_weight_path, output_path, zip_weights):
1517
from tensorflow import saved_model
1618

tests/test_cli.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,3 +125,10 @@ def test_torch_to_onnx(unet2d_nuclei_broad_model, tmp_path):
125125
ret = run_subprocess(["bioimageio", "convert-torch-weights-to-onnx", str(unet2d_nuclei_broad_model), str(out_path)])
126126
assert ret.returncode == 0, ret.stdout
127127
assert out_path.exists()
128+
129+
130+
def test_keras_to_tf(unet2d_keras, tmp_path):
131+
out_path = tmp_path / "weights.zip"
132+
ret = run_subprocess(["bioimageio", "convert-keras-weights-to-tensorflow", str(unet2d_keras), str(out_path)])
133+
assert ret.returncode == 0, ret.stdout
134+
assert out_path.exists()

0 commit comments

Comments
 (0)