Skip to content

Commit 70b7918

Browse files
Add keras test model
1 parent 1200b5a commit 70b7918

File tree

2 files changed

+14
-7
lines changed

2 files changed

+14
-7
lines changed

tests/build_spec/test_build_spec.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,6 @@ def _test_build_spec(
6767
output_path=out_path,
6868
add_deepimagej_config=add_deepimagej_config,
6969
)
70-
# TODO names
7170
if architecture is not None:
7271
kwargs["architecture"] = architecture
7372
if model_kwargs is not None:
@@ -134,5 +133,5 @@ def test_build_spec_deepimagej(unet2d_nuclei_broad_model, tmp_path):
134133
_test_build_spec(unet2d_nuclei_broad_model, tmp_path / "model.zip", "pytorch_script", add_deepimagej_config=True)
135134

136135

137-
# def test_build_spec_deepimagej_keras(unet2d_keras, tmp_path):
138-
# _test_build_spec(unet2d_keras, tmp_path / "model.zip", "pytorch_script", add_deepimagej_config=True)
136+
def test_build_spec_deepimagej_keras(unet2d_keras, tmp_path):
137+
_test_build_spec(unet2d_keras, tmp_path / "model.zip", "keras_hdf5", add_deepimagej_config=True)

tests/conftest.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,14 @@
1313
onnx_models = ["unet2d_multi_tensor", "unet2d_nuclei_broad_model", "hpa_densenet"]
1414
tensorflow1_models = ["stardist"]
1515
tensorflow2_models = []
16-
keras_models = []
16+
keras_models = ["unet2d_keras"]
1717
tensorflow_js_models = []
1818

1919
model_sources = {
20-
# TODO add unet2d_keras_tf from https://github.com/bioimage-io/spec-bioimage-io/pull/267
21-
# "unet2d_keras_tf": (""),
20+
"unet2d_keras": (
21+
"https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/"
22+
"unet2d_keras_tf/rdf.yaml"
23+
),
2224
"unet2d_nuclei_broad_model": (
2325
"https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/"
2426
"unet2d_nuclei_broad/rdf.yaml"
@@ -79,7 +81,6 @@
7981
except ImportError:
8082
keras = None
8183
skip_keras = keras is None
82-
skip_keras = True # TODO add unet2d_keras_tf to have a model for keras tests
8384

8485
# load all model packages we need for testing
8586
load_model_packages = set()
@@ -197,3 +198,10 @@ def unet2d_fixed_shape_or_not(request):
197198
)
198199
def unet2d_multi_tensor_or_not(request):
199200
return pytest.model_packages[request.param]
201+
202+
203+
@pytest.fixture(
204+
params=[] if skip_keras else ["unet2d_keras"]
205+
)
206+
def unet2d_keras(request):
207+
return pytest.model_packages[request.param]

0 commit comments

Comments
 (0)