@@ -49,6 +49,7 @@
 
 skip_tensorflow = tensorflow is None
 skip_tensorflow = True  # todo: update FruNet and remove this
+skip_tensorflow_js = True  # todo: update FruNet and figure out how to test tensorflow_js weights in python
 
 try:
     import keras
@@ -74,80 +75,71 @@
 load_model_packages |= set(tensorflow2_models)
 
 
-# set 'skip_<FRAMEWORK>' flags as global pytest variables,
-# to deselect tests that require frameworks not available in current env
 def pytest_configure():
-    pytest.skip_torch = skip_torch
-    pytest.skip_onnx = skip_onnx
-    pytest.skip_tensorflow = skip_tensorflow
-    pytest.tf_major_version = tf_major_version
-    pytest.skip_keras = skip_keras
-
-    pytest.model_packages = {
-        name: export_resource_package(model_sources[name])
-        for name in (load_model_packages | {"unet2d_nuclei_broad_model"})  # always load unet2d_nuclei_broad_model
-    }
-
 
-@pytest.fixture
-def unet2d_nuclei_broad_model():
-    return pytest.model_packages["unet2d_nuclei_broad_model"]
+    # explicit skip flag needed for the pytorch to onnx converter test
+    pytest.skip_onnx = skip_onnx
 
+    # load all model packages used in tests
+    pytest.model_packages = {name: export_resource_package(model_sources[name]) for name in load_model_packages}
 
-@pytest.fixture
-def unet2d_multi_tensor():
-    return pytest.model_packages["unet2d_multi_tensor"]
 
+#
+# model groups of the form any_<weight format>_model that include all models providing a specific weight format
+#
 
-@pytest.fixture
-def FruNet_model():
-    return pytest.model_packages["FruNet_model"]
+# written as a model group to automatically skip on missing torch
+@pytest.fixture(params=[] if skip_torch else ["unet2d_nuclei_broad_model"])
+def unet2d_nuclei_broad_model(request):
+    return pytest.model_packages[request.param]
 
 
-#
-# model groups
-#
+# written as a model group to automatically skip on missing tensorflow 1
+@pytest.fixture(params=[] if skip_tensorflow or tf_major_version != 1 else ["FruNet_model"])
+def FruNet_model(request):
+    return pytest.model_packages[request.param]
 
 
-@pytest.fixture(params=torch_models)
+@pytest.fixture(params=[] if skip_torch else torch_models)
 def any_torch_model(request):
     return pytest.model_packages[request.param]
 
 
-@pytest.fixture(params=torchscript_models)
+@pytest.fixture(params=[] if skip_torch else torchscript_models)
 def any_torchscript_model(request):
     return pytest.model_packages[request.param]
 
 
-@pytest.fixture(params=onnx_models)
+@pytest.fixture(params=[] if skip_onnx else onnx_models)
 def any_onnx_model(request):
     return pytest.model_packages[request.param]
 
 
-@pytest.fixture(params=set(tensorflow1_models) | set(tensorflow2_models))
+@pytest.fixture(params=[] if skip_tensorflow else (set(tensorflow1_models) | set(tensorflow2_models)))
 def any_tensorflow_model(request):
     name = request.param
-    if (pytest.tf_major_version == 1 and name in tensorflow1_models) or (
-        pytest.tf_major_version == 2 and name in tensorflow2_models
-    ):
+    if (tf_major_version == 1 and name in tensorflow1_models) or (tf_major_version == 2 and name in tensorflow2_models):
         return pytest.model_packages[name]
 
 
-@pytest.fixture(params=keras_models)
+@pytest.fixture(params=[] if skip_keras else keras_models)
 def any_keras_model(request):
     return pytest.model_packages[request.param]
 
 
-@pytest.fixture(params=tensorflow_js_models)
+@pytest.fixture(params=[] if skip_tensorflow_js else tensorflow_js_models)
 def any_tensorflow_js_model(request):
     return pytest.model_packages[request.param]
 
 
+# fixture to test with all models that should run in the current environment
 @pytest.fixture(params=load_model_packages)
 def any_model(request):
     return pytest.model_packages[request.param]
 
 
-@pytest.fixture(params=["unet2d_nuclei_broad_model", "unet2d_fixed_shape"])
+# temporary fixture to test with only a manual selection of models rather than all of them
+# (models/functionality should be improved to get rid of this specific model group)
+@pytest.fixture(params=[] if skip_torch else ["unet2d_nuclei_broad_model", "unet2d_fixed_shape"])
 def unet2d_fixed_shape_or_not(request):
     return pytest.model_packages[request.param]
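
A note on the recurring `params=[] if skip_<framework> else <models>` construct: pytest applies its `empty_parameter_set_mark` (which defaults to skip) whenever a fixture is parametrized with an empty parameter set, so every test requesting such a fixture is reported as skipped instead of erroring on a missing framework. A minimal self-contained sketch of the pattern, assuming the same import-probe style as the skip flags above (the model names and the test are illustrative only, not from this diff):

import pytest

try:
    import torch
except ImportError:
    torch = None

skip_torch = torch is None  # same flag style as skip_tensorflow above

# with skip_torch == True the parameter list is empty and pytest
# skips every test that requests this fixture
@pytest.fixture(params=[] if skip_torch else ["model_a", "model_b"])
def any_demo_torch_model(request):
    return request.param

def test_model_name(any_demo_torch_model):
    assert any_demo_torch_model.startswith("model_")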
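
The one flag kept on the `pytest` namespace, `pytest.skip_onnx`, covers the case where a test is parametrized over torch models but additionally needs onnx for its conversion step, so the whole parametrization cannot simply be emptied. A hedged sketch of how the converter test might consume it (the test name and body are assumptions, not part of this diff):

import pytest

def test_pytorch_to_onnx_converter(unet2d_nuclei_broad_model):
    # the fixture already skips when torch is missing; onnx availability
    # still has to be checked explicitly before converting the weights
    if pytest.skip_onnx:
        pytest.skip("onnx not available")
    assert unet2d_nuclei_broad_model is not None  # stand-in for the real conversion check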