Skip to content

Commit db61a08

Browse files
authored
3725 fixes download test (#3728)
* fixes test Signed-off-by: Wenqi Li <[email protected]> * skip when downloading fails Signed-off-by: Wenqi Li <[email protected]>
1 parent b1a96c5 commit db61a08

9 files changed

+48
-110
lines changed

tests/test_cross_validation.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,11 @@
1111

1212
import os
1313
import unittest
14-
from urllib.error import ContentTooShortError, HTTPError
1514

1615
from monai.apps import CrossValidation, DecathlonDataset
1716
from monai.transforms import AddChanneld, Compose, LoadImaged, ScaleIntensityd, ToTensord
1817
from monai.utils.enums import PostFix
19-
from tests.utils import skip_if_quick
18+
from tests.utils import skip_if_downloading_fails, skip_if_quick
2019

2120

2221
class TestCrossValidation(unittest.TestCase):
@@ -51,14 +50,8 @@ def _test_dataset(dataset):
5150
download=True,
5251
)
5352

54-
try: # will start downloading if testing_dir doesn't have the Decathlon files
53+
with skip_if_downloading_fails():
5554
data = cvdataset.get_dataset(folds=0)
56-
except (ContentTooShortError, HTTPError, RuntimeError) as e:
57-
print(str(e))
58-
if isinstance(e, RuntimeError):
59-
# FIXME: skip MD5 check as current downloading method may fail
60-
self.assertTrue(str(e).startswith("md5 check"))
61-
return # skipping this test due the network connection errors
6255

6356
_test_dataset(data)
6457

tests/test_decathlondataset.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,11 @@
1313
import shutil
1414
import unittest
1515
from pathlib import Path
16-
from urllib.error import ContentTooShortError, HTTPError
1716

1817
from monai.apps import DecathlonDataset
1918
from monai.transforms import AddChanneld, Compose, LoadImaged, ScaleIntensityd, ToTensord
2019
from monai.utils.enums import PostFix
21-
from tests.utils import skip_if_quick
20+
from tests.utils import skip_if_downloading_fails, skip_if_quick
2221

2322

2423
class TestDecathlonDataset(unittest.TestCase):
@@ -41,7 +40,7 @@ def _test_dataset(dataset):
4140
self.assertTrue(PostFix.meta("image") in dataset[0])
4241
self.assertTupleEqual(dataset[0]["image"].shape, (1, 36, 47, 44))
4342

44-
try: # will start downloading if testing_dir doesn't have the Decathlon files
43+
with skip_if_downloading_fails():
4544
data = DecathlonDataset(
4645
root_dir=testing_dir,
4746
task="Task04_Hippocampus",
@@ -50,12 +49,6 @@ def _test_dataset(dataset):
5049
download=True,
5150
copy_cache=False,
5251
)
53-
except (ContentTooShortError, HTTPError, RuntimeError) as e:
54-
print(str(e))
55-
if isinstance(e, RuntimeError):
56-
# FIXME: skip MD5 check as current downloading method may fail
57-
self.assertTrue(str(e).startswith("md5 check"))
58-
return # skipping this test due the network connection errors
5952

6053
_test_dataset(data)
6154
data = DecathlonDataset(

tests/test_download_and_extract.py

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from urllib.error import ContentTooShortError, HTTPError
1717

1818
from monai.apps import download_and_extract, download_url, extractall
19-
from tests.utils import skip_if_quick
19+
from tests.utils import skip_if_downloading_fails, skip_if_quick
2020

2121

2222
class TestDownloadAndExtract(unittest.TestCase):
@@ -27,22 +27,15 @@ def test_actions(self):
2727
filepath = Path(testing_dir) / "MedNIST.tar.gz"
2828
output_dir = Path(testing_dir)
2929
md5_value = "0bc7306e7427e00ad1c5526a6677552d"
30-
try:
30+
with skip_if_downloading_fails():
3131
download_and_extract(url, filepath, output_dir, md5_value)
3232
download_and_extract(url, filepath, output_dir, md5_value)
33-
except (ContentTooShortError, HTTPError, RuntimeError) as e:
34-
print(str(e))
35-
if isinstance(e, RuntimeError):
36-
# FIXME: skip MD5 check as current downloading method may fail
37-
self.assertTrue(str(e).startswith("md5 check"))
38-
return # skipping this test due the network connection errors
3933

4034
wrong_md5 = "0"
4135
with self.assertLogs(logger="monai.apps", level="ERROR"):
4236
try:
4337
download_url(url, filepath, wrong_md5)
4438
except (ContentTooShortError, HTTPError, RuntimeError) as e:
45-
print(str(e))
4639
if isinstance(e, RuntimeError):
4740
# FIXME: skip MD5 check as current downloading method may fail
4841
self.assertTrue(str(e).startswith("md5 check"))
@@ -56,7 +49,7 @@ def test_actions(self):
5649
@skip_if_quick
5750
def test_default(self):
5851
with tempfile.TemporaryDirectory() as tmp_dir:
59-
try:
52+
with skip_if_downloading_fails():
6053
# icon.tar.gz https://drive.google.com/file/d/1HrQd-AKPbts9jkTNN4pT8vLZyhM5irVn/view?usp=sharing
6154
download_and_extract(
6255
"https://drive.google.com/uc?id=1HrQd-AKPbts9jkTNN4pT8vLZyhM5irVn",
@@ -71,12 +64,6 @@ def test_default(self):
7164
hash_val="ac6e167ee40803577d98237f2b0241e5",
7265
file_type="zip",
7366
)
74-
except (ContentTooShortError, HTTPError, RuntimeError) as e:
75-
print(str(e))
76-
if isinstance(e, RuntimeError):
77-
# FIXME: skip MD5 check as current downloading method may fail
78-
self.assertTrue(str(e).startswith("md5 check"))
79-
return # skipping this test due the network connection errors
8067

8168

8269
if __name__ == "__main__":

tests/test_efficientnet.py

Lines changed: 4 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
import unittest
1414
from typing import TYPE_CHECKING
1515
from unittest import skipUnless
16-
from urllib.error import ContentTooShortError, HTTPError
1716

1817
import torch
1918
from parameterized import parameterized
@@ -27,7 +26,7 @@
2726
get_efficientnet_image_size,
2827
)
2928
from monai.utils import optional_import
30-
from tests.utils import skip_if_quick, test_pretrained_networks, test_script_save
29+
from tests.utils import skip_if_downloading_fails, skip_if_quick, test_pretrained_networks, test_script_save
3130

3231
if TYPE_CHECKING:
3332
import torchvision
@@ -251,12 +250,8 @@ class TestEFFICIENTNET(unittest.TestCase):
251250
def test_shape(self, input_param, input_shape, expected_shape):
252251
device = "cuda" if torch.cuda.is_available() else "cpu"
253252

254-
try:
255-
# initialize model
253+
with skip_if_downloading_fails():
256254
net = EfficientNetBN(**input_param).to(device)
257-
except (ContentTooShortError, HTTPError, RuntimeError) as e:
258-
print(str(e))
259-
return # skipping the tests because of http errors
260255

261256
# run inference with random tensor
262257
with eval_mode(net):
@@ -269,12 +264,8 @@ def test_shape(self, input_param, input_shape, expected_shape):
269264
def test_non_default_shapes(self, input_param, input_shape, expected_shape):
270265
device = "cuda" if torch.cuda.is_available() else "cpu"
271266

272-
try:
273-
# initialize model
267+
with skip_if_downloading_fails():
274268
net = EfficientNetBN(**input_param).to(device)
275-
except (ContentTooShortError, HTTPError, RuntimeError) as e:
276-
print(str(e))
277-
return # skipping the tests because of http errors
278269

279270
# override input shape with different variations
280271
num_dims = len(input_shape) - 2
@@ -387,12 +378,8 @@ class TestExtractFeatures(unittest.TestCase):
387378
def test_shape(self, input_param, input_shape, expected_shapes):
388379
device = "cuda" if torch.cuda.is_available() else "cpu"
389380

390-
try:
391-
# initialize model
381+
with skip_if_downloading_fails():
392382
net = EfficientNetBNFeatures(**input_param).to(device)
393-
except (ContentTooShortError, HTTPError, RuntimeError) as e:
394-
print(str(e))
395-
return # skipping the tests because of http errors
396383

397384
# run inference with random tensor
398385
with eval_mode(net):

tests/test_integration_classification_2d.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
import os
1313
import unittest
1414
import warnings
15-
from urllib.error import ContentTooShortError, HTTPError
1615

1716
import numpy as np
1817
import torch
@@ -39,7 +38,7 @@
3938
)
4039
from monai.utils import set_determinism
4140
from tests.testing_data.integration_answers import test_integration_value
42-
from tests.utils import DistTestCase, TimedCall, skip_if_quick
41+
from tests.utils import DistTestCase, TimedCall, skip_if_downloading_fails, skip_if_quick
4342

4443
TEST_DATA_URL = "https://drive.google.com/uc?id=1QsnnkvZyJPcbRoV_ArW8SnE1OTuoVbKE"
4544
MD5_VALUE = "0bc7306e7427e00ad1c5526a6677552d"
@@ -186,14 +185,8 @@ def setUp(self):
186185
dataset_file = os.path.join(self.data_dir, "MedNIST.tar.gz")
187186

188187
if not os.path.exists(data_dir):
189-
try:
188+
with skip_if_downloading_fails():
190189
download_and_extract(TEST_DATA_URL, dataset_file, self.data_dir, MD5_VALUE)
191-
except (ContentTooShortError, HTTPError, RuntimeError) as e:
192-
print(str(e))
193-
if isinstance(e, RuntimeError):
194-
# FIXME: skip MD5 check as current downloading method may fail
195-
self.assertTrue(str(e).startswith("md5 check"))
196-
return # skipping this test due the network connection errors
197190

198191
assert os.path.exists(data_dir)
199192

tests/test_lr_finder.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from monai.optimizers import LearningRateFinder
2525
from monai.transforms import AddChanneld, Compose, LoadImaged, ScaleIntensityd, ToTensord
2626
from monai.utils import optional_import, set_determinism
27+
from tests.utils import skip_if_downloading_fails
2728

2829
if TYPE_CHECKING:
2930
import matplotlib.pyplot as plt
@@ -61,14 +62,15 @@ def setUp(self):
6162

6263
def test_lr_finder(self):
6364
# 0.001 gives 54 examples
64-
train_ds = MedNISTDataset(
65-
root_dir=self.root_dir,
66-
transform=self.transforms,
67-
section="validation",
68-
val_frac=0.001,
69-
download=True,
70-
num_workers=10,
71-
)
65+
with skip_if_downloading_fails():
66+
train_ds = MedNISTDataset(
67+
root_dir=self.root_dir,
68+
transform=self.transforms,
69+
section="validation",
70+
val_frac=0.001,
71+
download=True,
72+
num_workers=10,
73+
)
7274
train_loader = DataLoader(train_ds, batch_size=300, shuffle=True, num_workers=10)
7375
num_classes = train_ds.get_num_classes()
7476

tests/test_mednistdataset.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,11 @@
1313
import shutil
1414
import unittest
1515
from pathlib import Path
16-
from urllib.error import ContentTooShortError, HTTPError
1716

1817
from monai.apps import MedNISTDataset
1918
from monai.transforms import AddChanneld, Compose, LoadImaged, ScaleIntensityd, ToTensord
2019
from monai.utils.enums import PostFix
21-
from tests.utils import skip_if_quick
20+
from tests.utils import skip_if_downloading_fails, skip_if_quick
2221

2322
MEDNIST_FULL_DATASET_LENGTH = 58954
2423

@@ -43,16 +42,10 @@ def _test_dataset(dataset):
4342
self.assertTrue(PostFix.meta("image") in dataset[0])
4443
self.assertTupleEqual(dataset[0]["image"].shape, (1, 64, 64))
4544

46-
try: # will start downloading if testing_dir doesn't have the MedNIST files
45+
with skip_if_downloading_fails():
4746
data = MedNISTDataset(
4847
root_dir=testing_dir, transform=transform, section="test", download=True, copy_cache=False
4948
)
50-
except (ContentTooShortError, HTTPError, RuntimeError) as e:
51-
print(str(e))
52-
if isinstance(e, RuntimeError):
53-
# FIXME: skip MD5 check as current downloading method may fail
54-
self.assertTrue(str(e).startswith("md5 check"))
55-
return # skipping this test due the network connection errors
5649

5750
_test_dataset(data)
5851

tests/test_mmar_download.py

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
import tempfile
1414
import unittest
1515
from pathlib import Path
16-
from urllib.error import ContentTooShortError, HTTPError
1716

1817
import numpy as np
1918
import torch
@@ -22,7 +21,7 @@
2221
from monai.apps import RemoteMMARKeys, download_mmar, get_model_spec, load_from_mmar
2322
from monai.apps.mmars import MODEL_DESC
2423
from monai.apps.mmars.mmars import _get_val
25-
from tests.utils import skip_if_quick
24+
from tests.utils import skip_if_downloading_fails, skip_if_quick
2625

2726
TEST_CASES = [["clara_pt_prostate_mri_segmentation_1"], ["clara_pt_covid19_ct_lesion_segmentation_1"]]
2827
TEST_EXTRACT_CASES = [
@@ -105,7 +104,7 @@ class TestMMMARDownload(unittest.TestCase):
105104
@parameterized.expand(TEST_CASES)
106105
@skip_if_quick
107106
def test_download(self, idx):
108-
try:
107+
with skip_if_downloading_fails():
109108
# test model specification
110109
cand = get_model_spec(idx)
111110
self.assertEqual(cand[RemoteMMARKeys.ID], idx)
@@ -116,22 +115,12 @@ def test_download(self, idx):
116115
download_mmar(idx, mmar_dir=tmp_dir, progress=False)
117116
download_mmar(idx, mmar_dir=Path(tmp_dir), progress=False, version=1) # repeated to check caching
118117
self.assertTrue(os.path.exists(os.path.join(tmp_dir, idx)))
119-
except (ContentTooShortError, HTTPError, RuntimeError) as e:
120-
print(str(e))
121-
if isinstance(e, HTTPError):
122-
self.assertTrue("500" in str(e)) # http error has the code 500
123-
return # skipping this test due the network connection errors
124118

125119
@parameterized.expand(TEST_EXTRACT_CASES)
126120
@skip_if_quick
127121
def test_load_ckpt(self, input_args, expected_name, expected_val):
128-
try:
122+
with skip_if_downloading_fails():
129123
output = load_from_mmar(**input_args)
130-
except (ContentTooShortError, HTTPError, RuntimeError) as e:
131-
print(str(e))
132-
if isinstance(e, HTTPError):
133-
self.assertTrue("500" in str(e)) # http error has the code 500
134-
return
135124
self.assertEqual(output.__class__.__name__, expected_name)
136125
x = next(output.parameters()) # verify the first element
137126
np.testing.assert_allclose(x[0][0].detach().cpu().numpy(), expected_val, rtol=1e-3, atol=1e-3)

tests/utils.py

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,11 @@
2121
import traceback
2222
import unittest
2323
import warnings
24+
from contextlib import contextmanager
2425
from functools import partial
2526
from subprocess import PIPE, Popen
2627
from typing import Callable, Optional, Tuple
27-
from urllib.error import ContentTooShortError, HTTPError, URLError
28+
from urllib.error import ContentTooShortError, HTTPError
2829

2930
import numpy as np
3031
import torch
@@ -93,17 +94,25 @@ def assert_allclose(
9394
np.testing.assert_allclose(actual, desired, *args, **kwargs)
9495

9596

96-
def test_pretrained_networks(network, input_param, device):
97+
@contextmanager
98+
def skip_if_downloading_fails():
9799
try:
100+
yield
101+
except (ContentTooShortError, HTTPError, ConnectionError) as e:
102+
raise unittest.SkipTest(f"error while downloading: {e}") from e
103+
except RuntimeError as rt_e:
104+
if "network issue" in str(rt_e):
105+
raise unittest.SkipTest(f"error while downloading: {rt_e}") from rt_e
106+
if "gdown dependency" in str(rt_e): # no gdown installed
107+
raise unittest.SkipTest(f"error while downloading: {rt_e}") from rt_e
108+
if "md5 check" in str(rt_e):
109+
raise unittest.SkipTest(f"error while downloading: {rt_e}") from rt_e
110+
raise rt_e
111+
112+
113+
def test_pretrained_networks(network, input_param, device):
114+
with skip_if_downloading_fails():
98115
return network(**input_param).to(device)
99-
except (URLError, HTTPError) as e:
100-
raise unittest.SkipTest(e) from e
101-
except RuntimeError as r_error:
102-
if "unexpected EOF" in f"{r_error}": # The file might be corrupted.
103-
raise unittest.SkipTest(f"{r_error}") from r_error
104-
if "network issue" in f"{r_error}": # The network is not available.
105-
raise unittest.SkipTest(f"{r_error}") from r_error
106-
raise
107116

108117

109118
def test_is_quick():
@@ -651,16 +660,8 @@ def test_script_save(net, *inputs, device=None, rtol=1e-4, atol=0.0):
651660

652661
def download_url_or_skip_test(*args, **kwargs):
653662
"""``download_url`` and skip the tests if any downloading error occurs."""
654-
try:
663+
with skip_if_downloading_fails():
655664
download_url(*args, **kwargs)
656-
except (ContentTooShortError, HTTPError) as e:
657-
raise unittest.SkipTest(f"error while downloading: {e}") from e
658-
except RuntimeError as rt_e:
659-
if "network issue" in str(rt_e):
660-
raise unittest.SkipTest(f"error while downloading: {rt_e}") from rt_e
661-
if "gdown dependency" in str(rt_e): # no gdown installed
662-
raise unittest.SkipTest(f"error while downloading: {rt_e}") from rt_e
663-
raise rt_e
664665

665666

666667
def query_memory(n=2):

0 commit comments

Comments
 (0)