Skip to content

Commit 31e54cb

Browse files
committed
replace skip_<weith format> flags by model groups
1 parent fb20d9f commit 31e54cb

File tree

8 files changed

+33
-78
lines changed

8 files changed

+33
-78
lines changed

tests/build_spec/test_build_spec.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
import os
22

3-
import pytest
4-
5-
import bioimageio.spec as spec
63
from marshmallow import missing
74

5+
import bioimageio.spec as spec
86
from bioimageio.core.resource_io.io_ import load_raw_resource_description
97

108

@@ -56,33 +54,27 @@ def _test_build_spec(path, weight_type, tensorflow_version=None):
5654
spec.model.schema.Model().dump(raw_model)
5755

5856

59-
@pytest.mark.skipif(pytest.skip_torch, reason="requires torch")
6057
def test_build_spec_pytorch(any_torch_model):
6158
_test_build_spec(any_torch_model, "pytorch_state_dict")
6259

6360

64-
@pytest.mark.skipif(pytest.skip_torch, reason="requires torch")
6561
def test_build_spec_torchscript(any_torchscript_model):
6662
_test_build_spec(any_torchscript_model, "pytorch_script")
6763

6864

69-
@pytest.mark.skipif(pytest.skip_onnx, reason="requires onnx")
7065
def test_build_spec_onnx(any_onnx_model):
7166
_test_build_spec(any_onnx_model, "onnx")
7267

7368

74-
@pytest.mark.skipif(pytest.skip_tensorflow or pytest.tf_major_version != 1, reason="requires tensorflow 1")
7569
def test_build_spec_keras(any_keras_model):
7670
_test_build_spec(any_keras_model, "keras_hdf5", tensorflow_version="1.12") # todo: keras for tf 2??
7771

7872

79-
@pytest.mark.skipif(pytest.skip_tensorflow, reason="requires tensorflow")
8073
def test_build_spec_tf(any_tensorflow_model):
8174
_test_build_spec(
8275
any_tensorflow_model, "tensorflow_saved_model_bundle", tensorflow_version="1.12"
8376
) # check tf version
8477

8578

86-
@pytest.mark.skipif(pytest.skip_tensorflow, reason="requires tensorflow")
8779
def test_build_spec_tfjs(any_tensorflow_js_model):
8880
_test_build_spec(any_tensorflow_js_model, "tensorflow_js", tensorflow_version="1.12")

tests/conftest.py

Lines changed: 27 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949

5050
skip_tensorflow = tensorflow is None
5151
skip_tensorflow = True # todo: update FruNet and remove this
52+
skip_tensorflow_js = True # todo: update FruNet and figure out how to test tensorflow_js weights in python
5253

5354
try:
5455
import keras
@@ -74,80 +75,71 @@
7475
load_model_packages |= set(tensorflow2_models)
7576

7677

77-
# set 'skip_<FRAMEWORK>' flags as global pytest variables,
78-
# to deselect tests that require frameworks not available in current env
7978
def pytest_configure():
80-
pytest.skip_torch = skip_torch
81-
pytest.skip_onnx = skip_onnx
82-
pytest.skip_tensorflow = skip_tensorflow
83-
pytest.tf_major_version = tf_major_version
84-
pytest.skip_keras = skip_keras
85-
86-
pytest.model_packages = {
87-
name: export_resource_package(model_sources[name])
88-
for name in (load_model_packages | {"unet2d_nuclei_broad_model"}) # always load unet2d_nuclei_broad_model
89-
}
90-
9179

92-
@pytest.fixture
93-
def unet2d_nuclei_broad_model():
94-
return pytest.model_packages["unet2d_nuclei_broad_model"]
80+
# explicit skip flag needed for pytorch to onnx converter test
81+
pytest.skip_onnx = skip_onnx
9582

83+
# load all model packages used in tests
84+
pytest.model_packages = {name: export_resource_package(model_sources[name]) for name in load_model_packages}
9685

97-
@pytest.fixture
98-
def unet2d_multi_tensor():
99-
return pytest.model_packages["unet2d_multi_tensor"]
10086

87+
#
88+
# model groups of the form any_<weight format>_model that include all models providing a specific weight format
89+
#
10190

102-
@pytest.fixture
103-
def FruNet_model():
104-
return pytest.model_packages["FruNet_model"]
91+
# written as model group to automatically skip on missing torch
92+
@pytest.fixture(params=[] if skip_torch else ["unet2d_nuclei_broad_model"])
93+
def unet2d_nuclei_broad_model(request):
94+
return pytest.model_packages[request.param]
10595

10696

107-
#
108-
# model groups
109-
#
97+
# written as model group to automatically skip on missing tensorflow 1
98+
@pytest.fixture(params=[] if skip_tensorflow or tf_major_version != 1 else ["FruNet_model"])
99+
def FruNet_model(request):
100+
return pytest.model_packages[request.param]
110101

111102

112-
@pytest.fixture(params=torch_models)
103+
@pytest.fixture(params=[] if skip_torch else torch_models)
113104
def any_torch_model(request):
114105
return pytest.model_packages[request.param]
115106

116107

117-
@pytest.fixture(params=torchscript_models)
108+
@pytest.fixture(params=[] if skip_torch else torchscript_models)
118109
def any_torchscript_model(request):
119110
return pytest.model_packages[request.param]
120111

121112

122-
@pytest.fixture(params=onnx_models)
113+
@pytest.fixture(params=[] if skip_onnx else onnx_models)
123114
def any_onnx_model(request):
124115
return pytest.model_packages[request.param]
125116

126117

127-
@pytest.fixture(params=set(tensorflow1_models) | set(tensorflow2_models))
118+
@pytest.fixture(params=[] if skip_tensorflow else (set(tensorflow1_models) | set(tensorflow2_models)))
128119
def any_tensorflow_model(request):
129120
name = request.param
130-
if (pytest.tf_major_version == 1 and name in tensorflow1_models) or (
131-
pytest.tf_major_version == 2 and name in tensorflow2_models
132-
):
121+
if (tf_major_version == 1 and name in tensorflow1_models) or (tf_major_version == 2 and name in tensorflow2_models):
133122
return pytest.model_packages[name]
134123

135124

136-
@pytest.fixture(params=keras_models)
125+
@pytest.fixture(params=[] if skip_keras else keras_models)
137126
def any_keras_model(request):
138127
return pytest.model_packages[request.param]
139128

140129

141-
@pytest.fixture(params=tensorflow_js_models)
130+
@pytest.fixture(params=[] if skip_tensorflow_js else tensorflow_js_models)
142131
def any_tensorflow_js_model(request):
143132
return pytest.model_packages[request.param]
144133

145134

135+
# fixture to test with all models that should run in the current environment
146136
@pytest.fixture(params=load_model_packages)
147137
def any_model(request):
148138
return pytest.model_packages[request.param]
149139

150140

151-
@pytest.fixture(params=["unet2d_nuclei_broad_model", "unet2d_fixed_shape"])
141+
# temporary fixture to test not with all, but only a manual selection of models
142+
# (models/functionality should be improved to get rid of this specific model group)
143+
@pytest.fixture(params=[] if skip_torch else ["unet2d_nuclei_broad_model", "unet2d_fixed_shape"])
152144
def unet2d_fixed_shape_or_not(request):
153145
return pytest.model_packages[request.param]

tests/prediction_pipeline/test_prediction_pipeline.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import numpy as np
2-
import pytest
32
import xarray as xr
43
from numpy.testing import assert_array_almost_equal
54

@@ -31,31 +30,21 @@ def _test_prediction_pipeline(model_package, weight_format):
3130
assert_array_almost_equal(out, exp, decimal=4)
3231

3332

34-
@pytest.mark.skipif(pytest.skip_torch, reason="requires torch")
3533
def test_prediction_pipeline_torch(any_torch_model):
3634
_test_prediction_pipeline(any_torch_model, "pytorch_state_dict")
3735

3836

39-
@pytest.mark.skipif(pytest.skip_torch, reason="requires torch")
4037
def test_prediction_pipeline_torchscript(any_torchscript_model):
4138
_test_prediction_pipeline(any_torchscript_model, "pytorch_script")
4239

4340

44-
@pytest.mark.skipif(pytest.skip_onnx, reason="requires onnx")
4541
def test_prediction_pipeline_onnx(any_onnx_model):
4642
_test_prediction_pipeline(any_onnx_model, "onnx")
4743

4844

49-
@pytest.mark.skipif(pytest.skip_tensorflow or pytest.tf_major_version != 1, reason="requires tensorflow 1")
50-
def test_prediction_pipeline_tensorflow(any_tensorflow1_model):
51-
_test_prediction_pipeline(any_tensorflow1_model, "tensorflow_saved_model_bundle")
45+
def test_prediction_pipeline_tensorflow(any_tensorflow_model):
46+
_test_prediction_pipeline(any_tensorflow_model, "tensorflow_saved_model_bundle")
5247

5348

54-
@pytest.mark.skipif(pytest.skip_tensorflow or pytest.tf_major_version != 2, reason="requires tensorflow 2")
55-
def test_prediction_pipeline_tensorflow(any_tensorflow2_model):
56-
_test_prediction_pipeline(any_tensorflow2_model, "tensorflow_saved_model_bundle")
57-
58-
59-
@pytest.mark.skipif(pytest.skip_keras, reason="requires keras")
6049
def test_prediction_pipeline_keras(any_keras_model):
6150
_test_prediction_pipeline(any_keras_model, "keras_hdf5")

tests/resource_io/test_load_rdf.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@ def test_load_model(any_model):
5252
assert model
5353

5454

55-
@pytest.mark.skipif(pytest.skip_torch, reason="requires pytorch")
5655
def test_load_model_with_abs_path_source(unet2d_nuclei_broad_model):
5756
from bioimageio.core.resource_io import load_raw_resource_description, load_resource_description
5857

@@ -63,7 +62,6 @@ def test_load_model_with_abs_path_source(unet2d_nuclei_broad_model):
6362
assert model
6463

6564

66-
@pytest.mark.skipif(pytest.skip_torch, reason="requires pytorch")
6765
def test_load_model_with_rel_path_source(unet2d_nuclei_broad_model):
6866
from bioimageio.core.resource_io import load_raw_resource_description, load_resource_description
6967

@@ -74,7 +72,6 @@ def test_load_model_with_rel_path_source(unet2d_nuclei_broad_model):
7472
assert model
7573

7674

77-
@pytest.mark.skipif(pytest.skip_torch, reason="requires pytorch")
7875
def test_load_model_with_abs_str_source(unet2d_nuclei_broad_model):
7976
from bioimageio.core.resource_io import load_raw_resource_description, load_resource_description
8077

@@ -85,7 +82,6 @@ def test_load_model_with_abs_str_source(unet2d_nuclei_broad_model):
8582
assert model
8683

8784

88-
@pytest.mark.skipif(pytest.skip_torch, reason="requires pytorch")
8985
def test_load_model_with_rel_str_source(unet2d_nuclei_broad_model):
9086
from bioimageio.core.resource_io import load_raw_resource_description, load_resource_description
9187

@@ -96,7 +92,6 @@ def test_load_model_with_rel_str_source(unet2d_nuclei_broad_model):
9692
assert model
9793

9894

99-
@pytest.mark.skipif(pytest.skip_torch, reason="requires pytorch")
10095
def test_load_remote_model_with_folders():
10196
from bioimageio.core import load_resource_description, load_raw_resource_description
10297
from bioimageio.core.resource_io import nodes

tests/test_cli.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,15 @@
22

33
import numpy as np
44
import pytest
5+
56
from bioimageio.core import load_resource_description
67

78

8-
@pytest.mark.skipif(pytest.skip_torch, reason="requires torch")
99
def test_cli_test_model(unet2d_nuclei_broad_model):
1010
ret = subprocess.run(["bioimageio", "test-model", unet2d_nuclei_broad_model])
1111
assert ret.returncode == 0
1212

1313

14-
@pytest.mark.skipif(pytest.skip_torch, reason="requires torch")
1514
def test_cli_predict_image(unet2d_nuclei_broad_model, tmp_path):
1615
spec = load_resource_description(unet2d_nuclei_broad_model)
1716
in_path = spec.test_inputs[0]
@@ -23,7 +22,6 @@ def test_cli_predict_image(unet2d_nuclei_broad_model, tmp_path):
2322
assert out_path.exists()
2423

2524

26-
@pytest.mark.skipif(pytest.skip_torch, reason="requires torch")
2725
def test_cli_predict_images(unet2d_nuclei_broad_model, tmp_path):
2826
n_images = 3
2927
shape = (1, 1, 128, 128)
@@ -50,7 +48,6 @@ def test_cli_predict_images(unet2d_nuclei_broad_model, tmp_path):
5048
assert np.load(out_path).shape == expected_shape
5149

5250

53-
@pytest.mark.skipif(pytest.skip_torch, reason="requires torch")
5451
def test_torch_to_torchscript(unet2d_nuclei_broad_model, tmp_path):
5552
out_path = tmp_path.with_suffix(".pt")
5653
ret = subprocess.run(
@@ -60,7 +57,7 @@ def test_torch_to_torchscript(unet2d_nuclei_broad_model, tmp_path):
6057
assert out_path.exists()
6158

6259

63-
@pytest.mark.skipif(pytest.skip_torch or pytest.skip_onnx, reason="requires torch and onnx")
60+
@pytest.mark.skipif(pytest.skip_onnx, reason="requires torch and onnx")
6461
def test_torch_to_onnx(unet2d_nuclei_broad_model, tmp_path):
6562
out_path = tmp_path.with_suffix(".onnx")
6663
ret = subprocess.run(["bioimageio", "convert-torch-weights-to-onnx", str(unet2d_nuclei_broad_model), str(out_path)])

tests/test_prediction.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,18 @@
22

33
import imageio
44
import numpy as np
5-
import pytest
6-
7-
from bioimageio.core import load_resource_description
85
from numpy.testing import assert_array_almost_equal
96

7+
from bioimageio.core import load_resource_description
108
from bioimageio.core.resource_io.nodes import Model
119

1210

13-
@pytest.mark.skipif(pytest.skip_torch, reason="requires torch")
1411
def test_test_model(unet2d_nuclei_broad_model):
1512
from bioimageio.core.prediction import test_model
1613

1714
assert test_model(unet2d_nuclei_broad_model)
1815

1916

20-
@pytest.mark.skipif(pytest.skip_torch, reason="requires torch")
2117
def test_predict_image(unet2d_fixed_shape_or_not, tmpdir):
2218
any_model = unet2d_fixed_shape_or_not # todo: replace 'unet2d_fixed_shape_or_not' with 'any_model'
2319
from bioimageio.core.prediction import predict_image
@@ -37,7 +33,6 @@ def test_predict_image(unet2d_fixed_shape_or_not, tmpdir):
3733
assert_array_almost_equal(res, exp, decimal=4)
3834

3935

40-
@pytest.mark.skipif(pytest.skip_torch, reason="requires torch")
4136
def test_predict_image_with_padding(unet2d_fixed_shape_or_not, tmp_path):
4237
any_model = unet2d_fixed_shape_or_not # todo: replace 'unet2d_fixed_shape_or_not' with 'any_model'
4338
from bioimageio.core.prediction import predict_image
@@ -74,7 +69,6 @@ def check_result():
7469
check_result()
7570

7671

77-
@pytest.mark.skipif(pytest.skip_torch, reason="requires torch")
7872
def test_predict_image_with_tiling(unet2d_nuclei_broad_model, tmp_path):
7973
from bioimageio.core.prediction import predict_image
8074

@@ -104,7 +98,6 @@ def check_result():
10498
check_result()
10599

106100

107-
@pytest.mark.skipif(pytest.skip_torch, reason="requires torch")
108101
def test_predict_images(unet2d_nuclei_broad_model, tmp_path):
109102
from bioimageio.core.prediction import predict_images
110103

tests/weight_converter/torch/test_onnx.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import pytest
33

44

5-
@pytest.mark.skipif(pytest.skip_torch, reason="requires pytorch")
65
def test_onnx_converter_from_torch(any_torch_model, tmp_path):
76
from bioimageio.core.weight_converter.torch.onnx import convert_weights_to_onnx
87

tests/weight_converter/torch/test_torchscript.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
import os
2-
import pytest
32

43

5-
@pytest.mark.skipif(pytest.skip_torch, reason="requires pytorch")
64
def test_torchscript_converter(any_torch_model, tmp_path):
75
from bioimageio.core.weight_converter.torch import convert_weights_to_pytorch_script
86

0 commit comments

Comments
 (0)