Skip to content

Commit 31ff72f

Browse files
committed
clean up keras/tensorflow fixtures
1 parent 9e37716 commit 31ff72f

File tree

1 file changed

+4
-12
lines changed

1 file changed

+4
-12
lines changed

tests/conftest.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -96,12 +96,6 @@
9696
skip_tensorflow = tensorflow is None
9797
skip_tensorflow_js = True # TODO: add a tensorflow_js example model
9898

99-
try:
100-
import keras
101-
except ImportError:
102-
keras = None
103-
skip_keras = keras is None
104-
10599
# load all model packages we need for testing
106100
load_model_packages = set()
107101
if not skip_torch:
@@ -152,14 +146,12 @@ def any_onnx_model(request):
152146
return pytest.model_packages[request.param]
153147

154148

155-
@pytest.fixture(params=[] if skip_tensorflow else (set(tensorflow1_models) | set(tensorflow2_models)))
149+
@pytest.fixture(params=[] if skip_tensorflow else tensorflow1_models if tf_major_version == 1 else tensorflow2_models)
156150
def any_tensorflow_model(request):
157-
name = request.param
158-
if (tf_major_version == 1 and name in tensorflow1_models) or (tf_major_version == 2 and name in tensorflow2_models):
159-
return pytest.model_packages[name]
151+
return pytest.model_packages[request.param]
160152

161153

162-
@pytest.fixture(params=[] if skip_keras else (set(keras_tf1_models) | set(keras_tf2_models)))
154+
@pytest.fixture(params=[] if skip_tensorflow else keras_tf1_models if tf_major_version == 1 else keras_tf2_models)
163155
def any_keras_model(request):
164156
return pytest.model_packages[request.param]
165157

@@ -194,7 +186,7 @@ def unet2d_multi_tensor_or_not(request):
194186
return pytest.model_packages[request.param]
195187

196188

197-
@pytest.fixture(params=[] if skip_keras else ["unet2d_keras" if tf_major_version == 1 else "unet2d_keras_tf2"])
189+
@pytest.fixture(params=[] if skip_tensorflow else ["unet2d_keras" if tf_major_version == 1 else "unet2d_keras_tf2"])
198190
def unet2d_keras(request):
199191
return pytest.model_packages[request.param]
200192

0 commit comments

Comments
 (0)