Skip to content

Commit a53daca

Browse files
Warn if current tensorflow version and model tensorflow version don't match
1 parent f0d9b8d commit a53daca

File tree

2 files changed

+33
-3
lines changed

2 files changed

+33
-3
lines changed

bioimageio/core/prediction_pipeline/_model_adapters/_keras_model_adapter.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,31 @@
44
# by default, we use the keras integrated with tensorflow
55
try:
66
from tensorflow import keras
7+
import tensorflow as tf
8+
TF_VERSION = tf.__version__
79
except Exception:
810
import keras
11+
TF_VERSION = None
912
import xarray as xr
1013

1114
from ._model_adapter import ModelAdapter
1215

1316

1417
class KerasModelAdapter(ModelAdapter):
1518
def _load(self, *, devices: Optional[Sequence[str]] = None) -> None:
19+
try:
20+
model_tf_version = self.bioimageio_model.weights[self.weight_format].tensorflow_version.version
21+
except AttributeError:
22+
model_tf_version = None
23+
24+
if TF_VERSION is None or model_tf_version is None:
25+
warnings.warn("Could not check tensorflow versions. The prediction results may be wrong.")
26+
elif tuple(model_tf_version[:2]) != tuple(map(int, TF_VERSION.split(".")))[:2]:
27+
warnings.warn(
28+
f"Model tensorflow version {model_tf_version} does not match {TF_VERSION}."
29+
"The prediction results may be wrong"
30+
)
31+
1632
# TODO keras device management
1733
if devices is not None:
1834
warnings.warn(f"Device management is not implemented for keras yet, ignoring the devices {devices}")

bioimageio/core/prediction_pipeline/_model_adapters/_tensorflow_model_adapter.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,24 @@ def _load_model(self, weight_file):
3636

3737
def _load(self, *, devices: Optional[List[str]] = None):
3838
try:
39-
tf_version = self.bioimageio_model.weights[self.weight_format].tensorflow_version.version
39+
model_tf_version = self.bioimageio_model.weights[self.weight_format].tensorflow_version.version
4040
except AttributeError:
41-
tf_version = (1, 14, 0)
42-
tf_major_ver = tf_version[0]
41+
model_tf_version = None
42+
43+
tf_version = tf.__version__
44+
tf_major_and_minor = tuple(map(int, tf_version.split(".")))[:2]
45+
if model_tf_version is None:
46+
warnings.warn(
47+
"The model did not contain metadata about the tensorflow version used for training."
48+
f"Cannot check if it is compatible with tf {tf_version}. The prediction result may be wrong."
49+
)
50+
elif tuple(model_tf_version[:2]) != tf_major_and_minor:
51+
warnings.warn(
52+
f"Model tensorflow version {model_tf_version} does not match {tf_version}."
53+
"The prediction results may be wrong"
54+
)
55+
56+
tf_major_ver = tf_major_and_minor[0]
4357
assert tf_major_ver in (1, 2)
4458
self.use_keras_api = tf_major_ver > 1 or self.weight_format == KerasModelAdapter.weight_format
4559

0 commit comments

Comments
 (0)