File tree Expand file tree Collapse file tree 2 files changed +33
-3
lines changed
bioimageio/core/prediction_pipeline/_model_adapters Expand file tree Collapse file tree 2 files changed +33
-3
lines changed Original file line number Diff line number Diff line change 44# by default, we use the keras integrated with tensorflow
55try :
66 from tensorflow import keras
7+ import tensorflow as tf
8+ TF_VERSION = tf .__version__
79except Exception :
810 import keras
11+ TF_VERSION = None
912import xarray as xr
1013
1114from ._model_adapter import ModelAdapter
1215
1316
1417class KerasModelAdapter (ModelAdapter ):
1518 def _load (self , * , devices : Optional [Sequence [str ]] = None ) -> None :
19+ try :
20+ model_tf_version = self .bioimageio_model .weights [self .weight_format ].tensorflow_version .version
21+ except AttributeError :
22+ model_tf_version = None
23+
24+ if TF_VERSION is None or model_tf_version is None :
25+ warnings .warn ("Could not check tensorflow versions. The prediction results may be wrong." )
26+ elif tuple (model_tf_version [:2 ]) != tuple (map (int , TF_VERSION .split ("." )))[:2 ]:
27+ warnings .warn (
28+ f"Model tensorflow version { model_tf_version } does not match { TF_VERSION } ."
29+ "The prediction results may be wrong"
30+ )
31+
1632 # TODO keras device management
1733 if devices is not None :
1834 warnings .warn (f"Device management is not implemented for keras yet, ignoring the devices { devices } " )
Original file line number Diff line number Diff line change @@ -36,10 +36,24 @@ def _load_model(self, weight_file):
3636
3737 def _load (self , * , devices : Optional [List [str ]] = None ):
3838 try :
39- tf_version = self .bioimageio_model .weights [self .weight_format ].tensorflow_version .version
39+ model_tf_version = self .bioimageio_model .weights [self .weight_format ].tensorflow_version .version
4040 except AttributeError :
41- tf_version = (1 , 14 , 0 )
42- tf_major_ver = tf_version [0 ]
41+ model_tf_version = None
42+
43+ tf_version = tf .__version__
44+ tf_major_and_minor = tuple (map (int , tf_version .split ("." )))[:2 ]
45+ if model_tf_version is None :
46+ warnings .warn (
47+ "The model did not contain metadata about the tensorflow version used for training."
48+ f"Cannot check if it is compatible with tf { tf_version } . The prediction result may be wrong."
49+ )
50+ elif tuple (model_tf_version [:2 ]) != tf_major_and_minor :
51+ warnings .warn (
52+ f"Model tensorflow version { model_tf_version } does not match { tf_version } ."
53+ "The prediction results may be wrong"
54+ )
55+
56+ tf_major_ver = tf_major_and_minor [0 ]
4357 assert tf_major_ver in (1 , 2 )
4458 self .use_keras_api = tf_major_ver > 1 or self .weight_format == KerasModelAdapter .weight_format
4559
You can’t perform that action at this time.
0 commit comments