|
96 | 96 | skip_tensorflow = tensorflow is None |
97 | 97 | skip_tensorflow_js = True # TODO: add a tensorflow_js example model |
98 | 98 |
|
99 | | -try: |
100 | | - import keras |
101 | | -except ImportError: |
102 | | - keras = None |
103 | | -skip_keras = keras is None |
104 | | - |
105 | 99 | # load all model packages we need for testing |
106 | 100 | load_model_packages = set() |
107 | 101 | if not skip_torch: |
@@ -152,14 +146,12 @@ def any_onnx_model(request): |
152 | 146 | return pytest.model_packages[request.param] |
153 | 147 |
|
154 | 148 |
|
155 | | -@pytest.fixture(params=[] if skip_tensorflow else (set(tensorflow1_models) | set(tensorflow2_models))) |
| 149 | +@pytest.fixture(params=[] if skip_tensorflow else tensorflow1_models if tf_major_version == 1 else tensorflow2_models) |
156 | 150 | def any_tensorflow_model(request): |
157 | | - name = request.param |
158 | | - if (tf_major_version == 1 and name in tensorflow1_models) or (tf_major_version == 2 and name in tensorflow2_models): |
159 | | - return pytest.model_packages[name] |
| 151 | + return pytest.model_packages[request.param] |
160 | 152 |
|
161 | 153 |
|
162 | | -@pytest.fixture(params=[] if skip_keras else (set(keras_tf1_models) | set(keras_tf2_models))) |
| 154 | +@pytest.fixture(params=[] if skip_tensorflow else keras_tf1_models if tf_major_version == 1 else keras_tf2_models) |
163 | 155 | def any_keras_model(request): |
164 | 156 | return pytest.model_packages[request.param] |
165 | 157 |
|
@@ -194,7 +186,7 @@ def unet2d_multi_tensor_or_not(request): |
194 | 186 | return pytest.model_packages[request.param] |
195 | 187 |
|
196 | 188 |
|
197 | | -@pytest.fixture(params=[] if skip_keras else ["unet2d_keras" if tf_major_version == 1 else "unet2d_keras_tf2"]) |
| 189 | +@pytest.fixture(params=[] if skip_tensorflow else ["unet2d_keras" if tf_major_version == 1 else "unet2d_keras_tf2"]) |
198 | 190 | def unet2d_keras(request): |
199 | 191 | return pytest.model_packages[request.param] |
200 | 192 |
|
|
0 commit comments