|
8 | 8 | from bioimageio.core import load_resource_description |
9 | 9 |
|
10 | 10 | import tensorflow |
11 | | -from tensorflow import keras |
| 11 | +from tensorflow import saved_model |
12 | 12 |
|
13 | 13 |
|
14 | 14 | # adapted from |
15 | 15 | # https://github.com/deepimagej/pydeepimagej/blob/master/pydeepimagej/yaml/create_config.py#L236 |
16 | 16 | def _convert_tf1(keras_weight_path, output_path, zip_weights): |
17 | | - from tensorflow import saved_model |
18 | 17 |
|
19 | | - keras_model = keras.models.load_model(keras_weight_path) |
20 | | - |
21 | | - builder = saved_model.builder.SavedModelBuilder(output_path) |
22 | | - signature = saved_model.signature_def_utils.predict_signature_def( |
23 | | - inputs={"input": keras_model.input}, outputs={"output": keras_model.output} |
24 | | - ) |
25 | | - |
26 | | - signature_def_map = {saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature} |
27 | | - |
28 | | - builder.add_meta_graph_and_variables( |
29 | | - keras.backend.get_session(), [saved_model.tag_constants.SERVING], signature_def_map=signature_def_map |
30 | | - ) |
31 | | - builder.save() |
| 18 | + def build_tf_model(): |
| 19 | + keras_model = keras.models.load_model(keras_weight_path) |
| 20 | + |
| 21 | + builder = saved_model.builder.SavedModelBuilder(output_path) |
| 22 | + signature = saved_model.signature_def_utils.predict_signature_def( |
| 23 | + inputs={"input": keras_model.input}, outputs={"output": keras_model.output} |
| 24 | + ) |
| 25 | + |
| 26 | + signature_def_map = {saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature} |
| 27 | + |
| 28 | + builder.add_meta_graph_and_variables( |
| 29 | + keras.backend.get_session(), [saved_model.tag_constants.SERVING], signature_def_map=signature_def_map |
| 30 | + ) |
| 31 | + builder.save() |
| 32 | + |
| 33 | + try: |
| 34 | + # try to build the tf model with the keras import from tensorflow |
| 35 | + from tensorflow import keras |
| 36 | + build_tf_model() |
| 37 | + except Exception: |
| 38 | + # if the above fails try to export with the standalone keras |
| 39 | + import keras |
| 40 | + build_tf_model() |
32 | 41 |
|
33 | 42 | if zip_weights: |
34 | 43 | zipped_model = f"{output_path}.zip" |
|
0 commit comments