Skip to content

Commit 87b2548

Browse files
Try weight conversion with tensorflow.keras and keras
1 parent 1c6b70c commit 87b2548

File tree

1 file changed

+24
-15
lines changed

1 file changed

+24
-15
lines changed

bioimageio/core/weight_converter/keras/tensorflow.py

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,27 +8,36 @@
88
from bioimageio.core import load_resource_description
99

1010
import tensorflow
11-
from tensorflow import keras
11+
from tensorflow import saved_model
1212

1313

1414
# adapted from
1515
# https://github.com/deepimagej/pydeepimagej/blob/master/pydeepimagej/yaml/create_config.py#L236
1616
def _convert_tf1(keras_weight_path, output_path, zip_weights):
17-
from tensorflow import saved_model
1817

19-
keras_model = keras.models.load_model(keras_weight_path)
20-
21-
builder = saved_model.builder.SavedModelBuilder(output_path)
22-
signature = saved_model.signature_def_utils.predict_signature_def(
23-
inputs={"input": keras_model.input}, outputs={"output": keras_model.output}
24-
)
25-
26-
signature_def_map = {saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature}
27-
28-
builder.add_meta_graph_and_variables(
29-
keras.backend.get_session(), [saved_model.tag_constants.SERVING], signature_def_map=signature_def_map
30-
)
31-
builder.save()
18+
def build_tf_model():
19+
keras_model = keras.models.load_model(keras_weight_path)
20+
21+
builder = saved_model.builder.SavedModelBuilder(output_path)
22+
signature = saved_model.signature_def_utils.predict_signature_def(
23+
inputs={"input": keras_model.input}, outputs={"output": keras_model.output}
24+
)
25+
26+
signature_def_map = {saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature}
27+
28+
builder.add_meta_graph_and_variables(
29+
keras.backend.get_session(), [saved_model.tag_constants.SERVING], signature_def_map=signature_def_map
30+
)
31+
builder.save()
32+
33+
try:
34+
# try to build the tf model with the keras import from tensorflow
35+
from tensorflow import keras
36+
build_tf_model()
37+
except Exception:
38+
# if the above fails try to export with the standalone keras
39+
import keras
40+
build_tf_model()
3241

3342
if zip_weights:
3443
zipped_model = f"{output_path}.zip"

0 commit comments

Comments
 (0)