1- import warnings
21import zipfile
32from typing import List , Literal , Optional , Sequence , Union
43
54import numpy as np
5+ from loguru import logger
66
77from bioimageio .spec .common import FileSource
88from bioimageio .spec .model import v0_4 , v0_5
@@ -46,19 +46,19 @@ def __init__(
4646 )
4747 model_tf_version = weights .tensorflow_version
4848 if model_tf_version is None :
49- warnings . warn (
49+ logger . warning (
5050 "The model does not specify the tensorflow version."
5151 + f"Cannot check if it is compatible with intalled tensorflow { tf_version } ."
5252 )
5353 elif model_tf_version > tf_version :
54- warnings . warn (
54+ logger . warning (
5555 f"The model specifies a newer tensorflow version than installed: { model_tf_version } > { tf_version } ."
5656 )
5757 elif (model_tf_version .major , model_tf_version .minor ) != (
5858 tf_version .major ,
5959 tf_version .minor ,
6060 ):
61- warnings . warn (
61+ logger . warning (
6262 "The tensorflow version specified by the model does not match the installed: "
6363 + f"{ model_tf_version } != { tf_version } ."
6464 )
@@ -70,7 +70,7 @@ def __init__(
7070
7171 # TODO tf device management
7272 if devices is not None :
73- warnings . warn (
73+ logger . warning (
7474 f"Device management is not implemented for tensorflow yet, ignoring the devices { devices } "
7575 )
7676
@@ -98,9 +98,20 @@ def _get_network( # pyright: ignore[reportUnknownParameterType]
9898 weight_file = self .require_unzipped (weight_file )
9999 assert tf is not None
100100 if self .use_keras_api :
101- return tf .keras .models .load_model (
102- weight_file , compile = False
103- ) # pyright: ignore[reportUnknownVariableType]
101+ try :
102+ return tf .keras .layers .TFSMLayer (
103+ weight_file , call_endpoint = "serve"
104+ ) # pyright: ignore[reportUnknownVariableType]
105+ except Exception as e :
106+ try :
107+ return tf .keras .layers .TFSMLayer (
108+ weight_file , call_endpoint = "serving_default"
109+ ) # pyright: ignore[reportUnknownVariableType]
110+ except Exception as ee :
111+ logger .opt (exception = ee ).info (
112+ "keras.layers.TFSMLayer error for alternative call_endpoint='serving_default'"
113+ )
114+ raise e
104115 else :
105116 # NOTE in tf1 the model needs to be loaded inside of the session, so we cannot preload the model
106117 return str (weight_file )
@@ -189,24 +200,15 @@ def _forward_keras( # pyright: ignore[reportUnknownParameterType]
189200 None if ipt is None else tf .convert_to_tensor (ipt ) for ipt in input_tensors
190201 ]
191202
192- try :
193- result = ( # pyright: ignore[reportUnknownVariableType]
194- self ._network .forward (* tf_tensor )
195- )
196- except AttributeError :
197- result = ( # pyright: ignore[reportUnknownVariableType]
198- self ._network .predict (* tf_tensor )
199- )
203+ result = self ._network (* tf_tensor ) # pyright: ignore[reportUnknownVariableType]
200204
201- if not isinstance (result , (tuple , list )):
202- result = [result ] # pyright: ignore[reportUnknownVariableType]
205+ assert isinstance (result , dict )
206+
207+ # TODO: Use RDF's `outputs[i].id` here
208+ result = list (result .values ())
203209
204210 return [ # pyright: ignore[reportUnknownVariableType]
205- (
206- None
207- if r is None
208- else r if isinstance (r , np .ndarray ) else tf .make_ndarray (r )
209- )
211+ (None if r is None else r if isinstance (r , np .ndarray ) else r .numpy ())
210212 for r in result # pyright: ignore[reportUnknownVariableType]
211213 ]
212214
@@ -230,7 +232,7 @@ def forward(self, *input_tensors: Optional[Tensor]) -> List[Optional[Tensor]]:
230232 ]
231233
232234 def unload (self ) -> None :
233- warnings . warn (
235+ logger . warning (
234236 "Device management is not implemented for keras yet, cannot unload model"
235237 )
236238
0 commit comments