Skip to content

Commit cf2ee8d

Browse files
committed
fix tensorflow_version arg for tests
1 parent 31ff72f commit cf2ee8d

File tree

1 file changed

+13
-4
lines changed

1 file changed

+13
-4
lines changed

tests/build_spec/test_build_spec.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Optional
2+
13
from marshmallow import missing
24

35
import bioimageio.spec as spec
@@ -6,6 +8,13 @@
68
from bioimageio.core.resource_io.utils import resolve_source
79
from bioimageio.core.resource_tests import test_model as _test_model
810

11+
try:
12+
import tensorflow
13+
except ImportError:
14+
tf_version = None
15+
else:
16+
tf_version: Optional[str] = ".".join(tensorflow.__version__.split(".")[:2])
17+
918

1019
def _test_build_spec(
1120
spec_path,
@@ -175,18 +184,18 @@ def test_build_spec_onnx(any_onnx_model, tmp_path):
175184

176185
def test_build_spec_keras(any_keras_model, tmp_path):
177186
_test_build_spec(
178-
any_keras_model, tmp_path / "model.zip", "keras_hdf5", tensorflow_version="1.12"
187+
any_keras_model, tmp_path / "model.zip", "keras_hdf5", tensorflow_version=tf_version
179188
) # todo: keras for tf 2??
180189

181190

182191
def test_build_spec_tf(any_tensorflow_model, tmp_path):
183192
_test_build_spec(
184-
any_tensorflow_model, tmp_path / "model.zip", "tensorflow_saved_model_bundle", tensorflow_version="1.12"
193+
any_tensorflow_model, tmp_path / "model.zip", "tensorflow_saved_model_bundle", tensorflow_version=tf_version
185194
) # check tf version
186195

187196

188197
def test_build_spec_tfjs(any_tensorflow_js_model, tmp_path):
189-
_test_build_spec(any_tensorflow_js_model, tmp_path / "model.zip", "tensorflow_js", tensorflow_version="1.12")
198+
_test_build_spec(any_tensorflow_js_model, tmp_path / "model.zip", "tensorflow_js", tensorflow_version=tf_version)
190199

191200

192201
def test_build_spec_deepimagej(unet2d_nuclei_broad_model, tmp_path):
@@ -220,7 +229,7 @@ def test_build_spec_parent2(unet2d_nuclei_broad_model, tmp_path):
220229

221230
def test_build_spec_deepimagej_keras(unet2d_keras, tmp_path):
222231
_test_build_spec(
223-
unet2d_keras, tmp_path / "model.zip", "keras_hdf5", add_deepimagej_config=True, tensorflow_version="1.12"
232+
unet2d_keras, tmp_path / "model.zip", "keras_hdf5", add_deepimagej_config=True, tensorflow_version=tf_version
224233
)
225234

226235

0 commit comments

Comments
 (0)