|
21 | 21 | import traceback |
22 | 22 | import unittest |
23 | 23 | import warnings |
| 24 | +from contextlib import contextmanager |
24 | 25 | from functools import partial |
25 | 26 | from subprocess import PIPE, Popen |
26 | 27 | from typing import Callable, Optional, Tuple |
27 | | -from urllib.error import ContentTooShortError, HTTPError, URLError |
| 28 | +from urllib.error import ContentTooShortError, HTTPError |
28 | 29 |
|
29 | 30 | import numpy as np |
30 | 31 | import torch |
@@ -93,17 +94,25 @@ def assert_allclose( |
93 | 94 | np.testing.assert_allclose(actual, desired, *args, **kwargs) |
94 | 95 |
|
95 | 96 |
|
96 | | -def test_pretrained_networks(network, input_param, device): |
| 97 | +@contextmanager |
| 98 | +def skip_if_downloading_fails(): |
97 | 99 | try: |
| 100 | + yield |
| 101 | + except (ContentTooShortError, HTTPError, ConnectionError) as e: |
| 102 | + raise unittest.SkipTest(f"error while downloading: {e}") from e |
| 103 | + except RuntimeError as rt_e: |
| 104 | + if "network issue" in str(rt_e): |
| 105 | + raise unittest.SkipTest(f"error while downloading: {rt_e}") from rt_e |
| 106 | + if "gdown dependency" in str(rt_e): # no gdown installed |
| 107 | + raise unittest.SkipTest(f"error while downloading: {rt_e}") from rt_e |
| 108 | + if "md5 check" in str(rt_e): |
| 109 | + raise unittest.SkipTest(f"error while downloading: {rt_e}") from rt_e |
| 110 | + raise rt_e |
| 111 | + |
| 112 | + |
| 113 | +def test_pretrained_networks(network, input_param, device): |
| 114 | + with skip_if_downloading_fails(): |
98 | 115 | return network(**input_param).to(device) |
99 | | - except (URLError, HTTPError) as e: |
100 | | - raise unittest.SkipTest(e) from e |
101 | | - except RuntimeError as r_error: |
102 | | - if "unexpected EOF" in f"{r_error}": # The file might be corrupted. |
103 | | - raise unittest.SkipTest(f"{r_error}") from r_error |
104 | | - if "network issue" in f"{r_error}": # The network is not available. |
105 | | - raise unittest.SkipTest(f"{r_error}") from r_error |
106 | | - raise |
107 | 116 |
|
108 | 117 |
|
109 | 118 | def test_is_quick(): |
@@ -651,16 +660,8 @@ def test_script_save(net, *inputs, device=None, rtol=1e-4, atol=0.0): |
651 | 660 |
|
652 | 661 | def download_url_or_skip_test(*args, **kwargs): |
653 | 662 | """``download_url`` and skip the tests if any downloading error occurs.""" |
654 | | - try: |
| 663 | + with skip_if_downloading_fails(): |
655 | 664 | download_url(*args, **kwargs) |
656 | | - except (ContentTooShortError, HTTPError) as e: |
657 | | - raise unittest.SkipTest(f"error while downloading: {e}") from e |
658 | | - except RuntimeError as rt_e: |
659 | | - if "network issue" in str(rt_e): |
660 | | - raise unittest.SkipTest(f"error while downloading: {rt_e}") from rt_e |
661 | | - if "gdown dependency" in str(rt_e): # no gdown installed |
662 | | - raise unittest.SkipTest(f"error while downloading: {rt_e}") from rt_e |
663 | | - raise rt_e |
664 | 665 |
|
665 | 666 |
|
666 | 667 | def query_memory(n=2): |
|
0 commit comments