 from bioimageio.core import export_resource_package

 # test models for various frameworks
-torch_models = ["unet2d_nuclei_broad_model", "unet2d_multi_tensor"]
-torchscript_models = ["unet2d_nuclei_broad_model", "unet2d_multi_tensor"]
-onnx_models = ["unet2d_nuclei_broad_model", "unet2d_multi_tensor"]
+torch_models = ["unet2d_fixed_shape", "unet2d_multi_tensor", "unet2d_nuclei_broad_model"]
+torchscript_models = ["unet2d_multi_tensor", "unet2d_nuclei_broad_model"]
+onnx_models = ["unet2d_multi_tensor", "unet2d_nuclei_broad_model"]
 tensorflow1_models = ["FruNet_model"]
 tensorflow2_models = []
 keras_models = ["FruNet_model"]
 tensorflow_js_models = ["FruNet_model"]

 model_sources = {
+    "FruNet_model": "https://sandbox.zenodo.org/record/894498/files/rdf.yaml",
+    # "FruNet_model": "https://raw.githubusercontent.com/deepimagej/models/master/fru-net_sev_segmentation/model.yaml",
     "unet2d_nuclei_broad_model": (
         "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/"
         "unet2d_nuclei_broad/rdf.yaml"
     ),
+    "unet2d_fixed_shape": (
+        "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/"
+        "unet2d_fixed_shape/rdf.yaml"
+    ),
     "unet2d_multi_tensor": (
         "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/"
         "unet2d_multi_tensor/rdf.yaml"
     ),
-    "FruNet_model": "https://sandbox.zenodo.org/record/894498/files/rdf.yaml",
-    # "FruNet_model": "https://raw.githubusercontent.com/deepimagej/models/master/fru-net_sev_segmentation/model.yaml",
 }

-# set 'skip_<FRAMEWORK>' flags as global pytest variables,
-# to deselect tests that require frameworks not available in current env
-def pytest_configure():
-    try:
-        import torch
-    except ImportError:
-        torch = None
-    pytest.skip_torch = torch is None
-
-    try:
-        import onnxruntime
-    except ImportError:
-        onnxruntime = None
-    pytest.skip_onnx = onnxruntime is None
+try:
+    import torch
+except ImportError:
+    torch = None
+skip_torch = torch is None
+
+try:
+    import onnxruntime
+except ImportError:
+    onnxruntime = None
+skip_onnx = onnxruntime is None
+
+try:
+    import tensorflow
+
+    tf_major_version = int(tensorflow.__version__.split(".")[0])
+except ImportError:
+    tensorflow = None
+    tf_major_version = None
+
+skip_tensorflow = tensorflow is None
+skip_tensorflow = True  # todo: update FruNet and remove this
+skip_tensorflow_js = True  # todo: update FruNet and figure out how to test tensorflow_js weights in python
+
+try:
+    import keras
+except ImportError:
+    keras = None
+skip_keras = keras is None
+skip_keras = True  # FruNet requires update
+
+# load all model packages we need for testing
+load_model_packages = set()
+if not skip_torch:
+    load_model_packages |= set(torch_models + torchscript_models)
+
+if not skip_onnx:
+    load_model_packages |= set(onnx_models)
+
+if not skip_tensorflow:
+    load_model_packages |= set(keras_models)
+    load_model_packages |= set(tensorflow_js_models)
+    if tf_major_version == 1:
+        load_model_packages |= set(tensorflow1_models)
+    elif tf_major_version == 2:
+        load_model_packages |= set(tensorflow2_models)

-    try:
-        import tensorflow
-
-        pytest.tf_major_version = int(tensorflow.__version__.split(".")[0])
-    except ImportError:
-        tensorflow = None
-    pytest.skip_tensorflow = tensorflow is None
-    pytest.skip_tensorflow = True  # todo: update FruNet and remove this
-
-    try:
-        import keras
-    except ImportError:
-        keras = None
-    pytest.skip_keras = keras is None
-    pytest.skip_keras = True  # FruNet requires update
-
-    # load all model packages we need for testing
-    load_packages = {"unet2d_nuclei_broad_model"}  # always load unet2d_nuclei_broad_model
-    if not pytest.skip_torch:
-        load_packages |= set(torch_models + torchscript_models)
-
-    if not pytest.skip_onnx:
-        load_packages |= set(onnx_models)
-
-    if not pytest.skip_tensorflow:
-        load_packages |= set(keras_models)
-        load_packages |= set(tensorflow_js_models)
-        if pytest.tf_major_version == 1:
-            load_packages |= set(tensorflow1_models)
-        elif pytest.tf_major_version == 2:
-            load_packages |= set(tensorflow2_models)
-
-    pytest.model_packages = {name: export_resource_package(model_sources[name]) for name in load_packages}

+def pytest_configure():

-@pytest.fixture
-def unet2d_nuclei_broad_model():
-    return pytest.model_packages["unet2d_nuclei_broad_model"]
+    # explicit skip flags needed for some tests
+    pytest.skip_torch = skip_torch
+    pytest.skip_onnx = skip_onnx

+    # load all model packages used in tests
+    pytest.model_packages = {name: export_resource_package(model_sources[name]) for name in load_model_packages}

-@pytest.fixture
-def unet2d_multi_tensor():
-    return pytest.model_packages["unet2d_multi_tensor"]

+#
+# model groups of the form any_<weight format>_model that include all models providing a specific weight format
+#

-@pytest.fixture
-def FruNet_model():
-    return pytest.model_packages["FruNet_model"]
+# written as model group to automatically skip on missing torch
+@pytest.fixture(params=[] if skip_torch else ["unet2d_nuclei_broad_model"])
+def unet2d_nuclei_broad_model(request):
+    return pytest.model_packages[request.param]


-#
-# model groups
-#
+# written as model group to automatically skip on missing tensorflow 1
+@pytest.fixture(params=[] if skip_tensorflow or tf_major_version != 1 else ["FruNet_model"])
+def FruNet_model(request):
+    return pytest.model_packages[request.param]


-@pytest.fixture(params=torch_models)
+@pytest.fixture(params=[] if skip_torch else torch_models)
 def any_torch_model(request):
     return pytest.model_packages[request.param]


-@pytest.fixture(params=torchscript_models)
+@pytest.fixture(params=[] if skip_torch else torchscript_models)
 def any_torchscript_model(request):
     return pytest.model_packages[request.param]


-@pytest.fixture(params=onnx_models)
+@pytest.fixture(params=[] if skip_onnx else onnx_models)
 def any_onnx_model(request):
     return pytest.model_packages[request.param]


-@pytest.fixture(params=set(tensorflow1_models) | set(tensorflow2_models))
+@pytest.fixture(params=[] if skip_tensorflow else (set(tensorflow1_models) | set(tensorflow2_models)))
 def any_tensorflow_model(request):
     name = request.param
-    if (pytest.tf_major_version == 1 and name in tensorflow1_models) or (
-        pytest.tf_major_version == 2 and name in tensorflow2_models
-    ):
+    if (tf_major_version == 1 and name in tensorflow1_models) or (tf_major_version == 2 and name in tensorflow2_models):
         return pytest.model_packages[name]


-@pytest.fixture(params=keras_models)
+@pytest.fixture(params=[] if skip_keras else keras_models)
 def any_keras_model(request):
     return pytest.model_packages[request.param]


-@pytest.fixture(params=tensorflow_js_models)
+@pytest.fixture(params=[] if skip_tensorflow_js else tensorflow_js_models)
 def any_tensorflow_js_model(request):
     return pytest.model_packages[request.param]


-@pytest.fixture(params=model_sources.keys())
+# fixture to test with all models that should run in the current environment
+@pytest.fixture(params=load_model_packages)
 def any_model(request):
     return pytest.model_packages[request.param]
+
+
+#
+# temporary fixtures to test not with all, but only a manual selection of models
+# (models/functionality should be improved to get rid of this specific model group)
+#
+@pytest.fixture(params=[] if skip_torch else ["unet2d_nuclei_broad_model", "unet2d_fixed_shape"])
+def unet2d_fixed_shape_or_not(request):
+    return pytest.model_packages[request.param]
+
+
+@pytest.fixture(params=[] if skip_torch else ["unet2d_nuclei_broad_model", "unet2d_multi_tensor"])
+def unet2d_multi_tensor_or_not(request):
+    return pytest.model_packages[request.param]
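For orientation, here is a minimal sketch of how a test module might consume the fixtures and flags defined in this conftest. The test names are hypothetical, and it assumes that export_resource_package returns a pathlib.Path to the exported package zip; none of this is part of the commit above.

# hypothetical test module (illustrative sketch only, not part of this commit)
import pytest


def test_any_onnx_model(any_onnx_model):
    # runs once per entry in onnx_models; when onnxruntime is missing the fixture
    # has an empty parameter list, so no test instances are collected at all
    assert any_onnx_model.exists()  # assumes the fixture yields a pathlib.Path to the package zip


def test_that_needs_torch_directly():
    # pytest.skip_torch / pytest.skip_onnx are set in pytest_configure for tests
    # that need an explicit in-test skip rather than an empty fixture parametrization
    if pytest.skip_torch:
        pytest.skip("requires torch")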