1919torchscript_models = ["unet2d_multi_tensor" , "unet2d_nuclei_broad_model" ]
2020onnx_models = ["unet2d_multi_tensor" , "unet2d_nuclei_broad_model" , "hpa_densenet" ]
2121tensorflow1_models = ["stardist" ]
22- tensorflow2_models = []
23- keras_models = ["unet2d_keras" ]
22+ tensorflow2_models = ["unet2d_keras_tf2" ]
23+ keras_tf1_models = ["unet2d_keras" ]
24+ keras_tf2_models = ["unet2d_keras_tf2" ]
2425tensorflow_js_models = []
2526
2627model_sources = {
2728 "unet2d_keras" : (
2829 "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/"
2930 "unet2d_keras_tf/rdf.yaml"
3031 ),
32+ "unet2d_keras_tf2" : (
33+ "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/"
34+ "unet2d_keras_tf2/rdf.yaml"
35+ ),
3136 "unet2d_nuclei_broad_model" : (
3237 "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/"
3338 "unet2d_nuclei_broad/rdf.yaml"
9196skip_tensorflow = tensorflow is None
9297skip_tensorflow_js = True # TODO: add a tensorflow_js example model
9398
94- try :
95- import keras
96- except ImportError :
97- keras = None
98- skip_keras = keras is None
99-
10099# load all model packages we need for testing
101100load_model_packages = set ()
102101if not skip_torch :
106105 load_model_packages |= set (onnx_models )
107106
108107if not skip_tensorflow :
109- load_model_packages |= set (keras_models )
110108 load_model_packages |= set (tensorflow_js_models )
111109 if tf_major_version == 1 :
110+ load_model_packages |= set (keras_tf1_models )
112111 load_model_packages |= set (tensorflow1_models )
113112 load_model_packages .add ("stardist_wrong_shape" )
114113 load_model_packages .add ("stardist_wrong_shape2" )
115114 elif tf_major_version == 2 :
115+ load_model_packages |= set (keras_tf2_models )
116116 load_model_packages |= set (tensorflow2_models )
117117
118118
@@ -146,14 +146,12 @@ def any_onnx_model(request):
146146 return pytest .model_packages [request .param ]
147147
148148
149- @pytest .fixture (params = [] if skip_tensorflow else ( set ( tensorflow1_models ) | set ( tensorflow2_models )) )
149+ @pytest .fixture (params = [] if skip_tensorflow else tensorflow1_models if tf_major_version == 1 else tensorflow2_models )
150150def any_tensorflow_model (request ):
151- name = request .param
152- if (tf_major_version == 1 and name in tensorflow1_models ) or (tf_major_version == 2 and name in tensorflow2_models ):
153- return pytest .model_packages [name ]
151+ return pytest .model_packages [request .param ]
154152
155153
156- @pytest .fixture (params = [] if skip_keras else keras_models )
154+ @pytest .fixture (params = [] if skip_tensorflow else keras_tf1_models if tf_major_version == 1 else keras_tf2_models )
157155def any_keras_model (request ):
158156 return pytest .model_packages [request .param ]
159157
@@ -178,21 +176,17 @@ def any_model(request):
178176#
179177
180178
181- @pytest .fixture (
182- params = [] if skip_torch else ["unet2d_nuclei_broad_model" , "unet2d_fixed_shape" ]
183- )
179+ @pytest .fixture (params = [] if skip_torch else ["unet2d_nuclei_broad_model" , "unet2d_fixed_shape" ])
184180def unet2d_fixed_shape_or_not (request ):
185181 return pytest .model_packages [request .param ]
186182
187183
188- @pytest .fixture (
189- params = [] if skip_torch else ["unet2d_nuclei_broad_model" , "unet2d_multi_tensor" ]
190- )
184+ @pytest .fixture (params = [] if skip_torch else ["unet2d_nuclei_broad_model" , "unet2d_multi_tensor" ])
191185def unet2d_multi_tensor_or_not (request ):
192186 return pytest .model_packages [request .param ]
193187
194188
195- @pytest .fixture (params = [] if skip_keras else ["unet2d_keras" ])
189+ @pytest .fixture (params = [] if skip_tensorflow else ["unet2d_keras" if tf_major_version == 1 else "unet2d_keras_tf2 " ])
196190def unet2d_keras (request ):
197191 return pytest .model_packages [request .param ]
198192
0 commit comments