|
11 | 11 | torch_models_pre_3_10 = ["unet2d_fixed_shape", "unet2d_multi_tensor", "unet2d_nuclei_broad_model"] |
12 | 12 | torchscript_models = ["unet2d_multi_tensor", "unet2d_nuclei_broad_model"] |
13 | 13 | onnx_models = ["unet2d_multi_tensor", "unet2d_nuclei_broad_model", "hpa_densenet"] |
14 | | -tensorflow1_models = ["FruNet_model", "stardist"] |
| 14 | +tensorflow1_models = ["stardist"] |
15 | 15 | tensorflow2_models = [] |
16 | | -keras_models = ["FruNet_model"] |
17 | | -tensorflow_js_models = ["FruNet_model"] |
| 16 | +keras_models = [] |
| 17 | +tensorflow_js_models = [] |
18 | 18 |
|
19 | 19 | model_sources = { |
20 | | - "FruNet_model": "https://sandbox.zenodo.org/record/894498/files/rdf.yaml", |
21 | | - # "FruNet_model": "https://raw.githubusercontent.com/deepimagej/models/master/fru-net_sev_segmentation/model.yaml", |
| 20 | + # TODO add unet2d_keras_tf from https://github.com/bioimage-io/spec-bioimage-io/pull/267 |
| 21 | + # "unet2d_keras_tf": (""), |
22 | 22 | "unet2d_nuclei_broad_model": ( |
23 | 23 | "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/" |
24 | 24 | "unet2d_nuclei_broad/rdf.yaml" |
|
35 | 35 | "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/hpa-densenet/rdf.yaml" |
36 | 36 | ), |
37 | 37 | "stardist": ( |
38 | | - "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/stardist_example_model/rdf.yaml" |
| 38 | + "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models" |
| 39 | + "/stardist_example_model/rdf.yaml" |
39 | 40 | ), |
40 | 41 | "stardist_wrong_shape": ( |
41 | | - "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/stardist_example_model/rdf_wrong_shape.yaml" |
| 42 | + "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/" |
| 43 | + "stardist_example_model/rdf_wrong_shape.yaml" |
42 | 44 | ), |
43 | 45 | } |
44 | 46 |
|
|
50 | 52 | except ImportError: |
51 | 53 | torch = None |
52 | 54 | torch_version = None |
53 | | - |
54 | 55 | skip_torch = torch is None |
55 | 56 |
|
56 | 57 | try: |
|
66 | 67 | except ImportError: |
67 | 68 | tensorflow = None |
68 | 69 | tf_major_version = None |
69 | | - |
70 | 70 | skip_tensorflow = tensorflow is None |
71 | | -skip_tensorflow = True # todo: update FruNet and remove this |
72 | | -skip_tensorflow_js = True # todo: update FruNet and figure out how to test tensorflow_js weights in python |
| 71 | +skip_tensorflow_js = True # TODO: add a tensorflow_js example model |
73 | 72 |
|
74 | 73 | try: |
75 | 74 | import keras |
76 | 75 | except ImportError: |
77 | 76 | keras = None |
78 | 77 | skip_keras = keras is None |
79 | | -skip_keras = True # FruNet requires update |
| 78 | +skip_keras = True # TODO add unet2d_keras_tf to have a model for keras tests |
80 | 79 |
|
81 | 80 | # load all model packages we need for testing |
82 | 81 | load_model_packages = set() |
@@ -120,14 +119,14 @@ def unet2d_nuclei_broad_model(request): |
120 | 119 |
|
121 | 120 |
|
122 | 121 | # written as model group to automatically skip on missing tensorflow 1 |
123 | | -@pytest.fixture(params=[] if skip_tensorflow or tf_major_version != 1 else ["FruNet_model"]) |
124 | | -def FruNet_model(request): |
| 122 | +@pytest.fixture(params=[] if skip_tensorflow or tf_major_version != 1 else ["stardist_wrong_shape"]) |
| 123 | +def stardist_wrong_shape(request): |
125 | 124 | return pytest.model_packages[request.param] |
126 | 125 |
|
127 | 126 |
|
128 | 127 | # written as model group to automatically skip on missing tensorflow 1 |
129 | | -@pytest.fixture(params=[] if skip_tensorflow or tf_major_version != 1 else ["stardist_wrong_shape"]) |
130 | | -def stardist_wrong_shape(request): |
| 128 | +@pytest.fixture(params=[] if skip_tensorflow or tf_major_version != 1 else ["stardist"]) |
| 129 | +def stardist(request): |
131 | 130 | return pytest.model_packages[request.param] |
132 | 131 |
|
133 | 132 |
|
@@ -164,7 +163,9 @@ def any_tensorflow_js_model(request): |
164 | 163 |
|
165 | 164 |
|
166 | 165 | # fixture to test with all models that should run in the current environment |
167 | | -@pytest.fixture(params=load_model_packages) |
| 166 | +# we exclude stardist_wrong_shape here because it is not a valid model |
| 167 | +# and included only to test that validation for this model fails |
| 168 | +@pytest.fixture(params=load_model_packages - {"stardist_wrong_shape"}) |
168 | 169 | def any_model(request): |
169 | 170 | return pytest.model_packages[request.param] |
170 | 171 |
|
|
0 commit comments