Skip to content

Commit e148cfd

Browse files
wyliericspod
andauthored
3284 torch version check (#3285)
* torch version check Signed-off-by: Wenqi Li <[email protected]> * temp tests Signed-off-by: Wenqi Li <[email protected]> * additional cases Signed-off-by: Wenqi Li <[email protected]> * fixes tests Signed-off-by: Wenqi Li <[email protected]> * update unit test names Signed-off-by: Wenqi Li <[email protected]> * remove temp tests Signed-off-by: Wenqi Li <[email protected]> * update based on comments Signed-off-by: Wenqi Li <[email protected]> * fixes codeformat Signed-off-by: Wenqi Li <[email protected]> Co-authored-by: Eric Kerfoot <[email protected]>
1 parent 9ddc9e6 commit e148cfd

File tree

9 files changed

+120
-28
lines changed

9 files changed

+120
-28
lines changed

monai/engines/trainer.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from monai.engines.workflow import Workflow
2727
from monai.inferers import Inferer, SimpleInferer
2828
from monai.transforms import Transform
29-
from monai.utils import PT_BEFORE_1_7, min_version, optional_import
29+
from monai.utils import min_version, optional_import, pytorch_after
3030
from monai.utils.enums import CommonKeys as Keys
3131

3232
if TYPE_CHECKING:
@@ -190,7 +190,7 @@ def _compute_pred_loss():
190190

191191
self.network.train()
192192
# `set_to_none` only work from PyTorch 1.7.0
193-
if PT_BEFORE_1_7:
193+
if not pytorch_after(1, 7):
194194
self.optimizer.zero_grad()
195195
else:
196196
self.optimizer.zero_grad(set_to_none=self.optim_set_to_none)
@@ -359,7 +359,7 @@ def _iteration(
359359
d_total_loss = torch.zeros(1)
360360
for _ in range(self.d_train_steps):
361361
# `set_to_none` only work from PyTorch 1.7.0
362-
if PT_BEFORE_1_7:
362+
if not pytorch_after(1, 7):
363363
self.d_optimizer.zero_grad()
364364
else:
365365
self.d_optimizer.zero_grad(set_to_none=self.optim_set_to_none)
@@ -377,7 +377,7 @@ def _iteration(
377377
non_blocking=engine.non_blocking, # type: ignore
378378
)
379379
g_output = self.g_inferer(g_input, self.g_network)
380-
if PT_BEFORE_1_7:
380+
if not pytorch_after(1, 7):
381381
self.g_optimizer.zero_grad()
382382
else:
383383
self.g_optimizer.zero_grad(set_to_none=self.optim_set_to_none)

monai/networks/layers/simplelayers.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,18 +21,17 @@
2121
from monai.networks.layers.convutils import gaussian_1d
2222
from monai.networks.layers.factories import Conv
2323
from monai.utils import (
24-
PT_BEFORE_1_7,
2524
ChannelMatching,
2625
InvalidPyTorchVersionError,
2726
SkipMode,
2827
look_up_option,
2928
optional_import,
30-
version_leq,
29+
pytorch_after,
3130
)
3231
from monai.utils.misc import issequenceiterable
3332

3433
_C, _ = optional_import("monai._C")
35-
if not PT_BEFORE_1_7:
34+
if pytorch_after(1, 7):
3635
fft, _ = optional_import("torch.fft")
3736

3837
__all__ = [
@@ -295,11 +294,12 @@ def apply_filter(x: torch.Tensor, kernel: torch.Tensor, **kwargs) -> torch.Tenso
295294
x = x.view(1, kernel.shape[0], *spatials)
296295
conv = [F.conv1d, F.conv2d, F.conv3d][n_spatial - 1]
297296
if "padding" not in kwargs:
298-
if version_leq(torch.__version__, "1.10.0b"):
297+
if pytorch_after(1, 10):
298+
kwargs["padding"] = "same"
299+
else:
299300
# even-sized kernels are not supported
300301
kwargs["padding"] = [(k - 1) // 2 for k in kernel.shape[2:]]
301-
else:
302-
kwargs["padding"] = "same"
302+
303303
if "stride" not in kwargs:
304304
kwargs["stride"] = 1
305305
output = conv(x, kernel, groups=kernel.shape[0], bias=None, **kwargs)
@@ -387,7 +387,7 @@ class HilbertTransform(nn.Module):
387387

388388
def __init__(self, axis: int = 2, n: Union[int, None] = None) -> None:
389389

390-
if PT_BEFORE_1_7:
390+
if not pytorch_after(1, 7):
391391
raise InvalidPyTorchVersionError("1.7.0", self.__class__.__name__)
392392

393393
super().__init__()

monai/networks/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
from monai.utils.deprecate_utils import deprecated_arg
2424
from monai.utils.misc import ensure_tuple, set_determinism
25-
from monai.utils.module import PT_BEFORE_1_7
25+
from monai.utils.module import pytorch_after
2626

2727
__all__ = [
2828
"one_hot",
@@ -464,7 +464,7 @@ def convert_to_torchscript(
464464
with torch.no_grad():
465465
script_module = torch.jit.script(model)
466466
if filename_or_obj is not None:
467-
if PT_BEFORE_1_7:
467+
if not pytorch_after(1, 7):
468468
torch.jit.save(m=script_module, f=filename_or_obj)
469469
else:
470470
torch.jit.save(m=script_module, f=filename_or_obj, _extra_files=extra_files)

monai/transforms/intensity/array.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,14 @@
3030
from monai.transforms.utils import Fourier, equalize_hist, is_positive, rescale_array
3131
from monai.transforms.utils_pytorch_numpy_unification import clip, percentile, where
3232
from monai.utils import (
33-
PT_BEFORE_1_7,
3433
InvalidPyTorchVersionError,
3534
convert_data_type,
3635
convert_to_dst_type,
3736
ensure_tuple,
3837
ensure_tuple_rep,
3938
ensure_tuple_size,
4039
fall_back_tuple,
40+
pytorch_after,
4141
)
4242
from monai.utils.deprecate_utils import deprecated_arg
4343
from monai.utils.enums import TransformBackends
@@ -1072,7 +1072,7 @@ class DetectEnvelope(Transform):
10721072

10731073
def __init__(self, axis: int = 1, n: Union[int, None] = None) -> None:
10741074

1075-
if PT_BEFORE_1_7:
1075+
if not pytorch_after(1, 7):
10761076
raise InvalidPyTorchVersionError("1.7.0", self.__class__.__name__)
10771077

10781078
if axis < 0:

monai/utils/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,6 @@
5858
zip_with,
5959
)
6060
from .module import (
61-
PT_BEFORE_1_7,
6261
InvalidPyTorchVersionError,
6362
OptionalImportError,
6463
damerau_levenshtein_distance,
@@ -71,6 +70,7 @@
7170
look_up_option,
7271
min_version,
7372
optional_import,
73+
pytorch_after,
7474
require_pkg,
7575
version_leq,
7676
)

monai/utils/module.py

Lines changed: 52 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,10 @@
88
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
11+
1112
import enum
13+
import os
14+
import re
1215
import sys
1316
import warnings
1417
from functools import wraps
@@ -36,8 +39,8 @@
3639
"get_full_type_name",
3740
"get_package_version",
3841
"get_torch_version_tuple",
39-
"PT_BEFORE_1_7",
4042
"version_leq",
43+
"pytorch_after",
4144
]
4245

4346

@@ -450,7 +453,51 @@ def _try_cast(val: str):
450453
return True
451454

452455

453-
try:
454-
PT_BEFORE_1_7 = torch.__version__ != "1.7.0" and version_leq(torch.__version__, "1.7.0")
455-
except (AttributeError, TypeError):
456-
PT_BEFORE_1_7 = True
456+
def pytorch_after(major, minor, patch=0, current_ver_string=None) -> bool:
457+
"""
458+
Compute whether the current pytorch version is after or equal to the specified version.
459+
The current system pytorch version is determined by `torch.__version__` or
460+
via system environment variable `PYTORCH_VER`.
461+
462+
Args:
463+
major: major version number to be compared with
464+
minor: minor version number to be compared with
465+
patch: patch version number to be compared with
466+
current_ver_string: if None, `torch.__version__` will be used.
467+
468+
Returns:
469+
True if the current pytorch version is greater than or equal to the specified version.
470+
"""
471+
472+
try:
473+
if current_ver_string is None:
474+
_env_var = os.environ.get("PYTORCH_VER", "")
475+
current_ver_string = _env_var if _env_var else torch.__version__
476+
ver, has_ver = optional_import("pkg_resources", name="parse_version")
477+
if has_ver:
478+
return ver(".".join((f"{major}", f"{minor}", f"{patch}"))) <= ver(f"{current_ver_string}") # type: ignore
479+
parts = f"{current_ver_string}".split("+", 1)[0].split(".", 3)
480+
while len(parts) < 3:
481+
parts += ["0"]
482+
c_major, c_minor, c_patch = parts[:3]
483+
except (AttributeError, ValueError, TypeError):
484+
c_major, c_minor = get_torch_version_tuple()
485+
c_patch = "0"
486+
c_mn = int(c_major), int(c_minor)
487+
mn = int(major), int(minor)
488+
if c_mn != mn:
489+
return c_mn > mn
490+
is_prerelease = ("a" in f"{c_patch}".lower()) or ("rc" in f"{c_patch}".lower())
491+
c_p = 0
492+
try:
493+
p_reg = re.search(r"\d+", f"{c_patch}")
494+
if p_reg:
495+
c_p = int(p_reg.group())
496+
except (AttributeError, TypeError, ValueError):
497+
is_prerelease = True
498+
patch = int(patch)
499+
if c_p != patch:
500+
return c_p > patch # type: ignore
501+
if is_prerelease:
502+
return False
503+
return True

tests/test_map_label_value.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from parameterized import parameterized
1717

1818
from monai.transforms import MapLabelValue
19-
from monai.utils import PT_BEFORE_1_7
19+
from monai.utils import pytorch_after
2020
from tests.utils import TEST_NDARRAYS
2121

2222
TESTS = []
@@ -34,7 +34,7 @@
3434
]
3535
)
3636
# PyTorch 1.5.1 doesn't support rich dtypes
37-
if not PT_BEFORE_1_7:
37+
if pytorch_after(1, 7):
3838
TESTS.append(
3939
[
4040
{"orig_labels": [1.5, 2.5, 3.5], "target_labels": [0, 1, 2], "dtype": np.int8},
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# Copyright 2020 - 2021 MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
13+
import unittest
14+
15+
from parameterized import parameterized
16+
17+
from monai.utils import pytorch_after
18+
19+
TEST_CASES = (
20+
(1, 5, 9, "1.6.0"),
21+
(1, 6, 0, "1.6.0"),
22+
(1, 6, 1, "1.6.0", False),
23+
(1, 7, 0, "1.6.0", False),
24+
(2, 6, 0, "1.6.0", False),
25+
(0, 6, 0, "1.6.0a0+3fd9dcf"),
26+
(1, 5, 9, "1.6.0a0+3fd9dcf"),
27+
(1, 6, 0, "1.6.0a0+3fd9dcf", False),
28+
(1, 6, 1, "1.6.0a0+3fd9dcf", False),
29+
(2, 6, 0, "1.6.0a0+3fd9dcf", False),
30+
(1, 6, 0, "1.6.0-rc0+3fd9dcf", False), # defaults to prerelease
31+
(1, 6, 0, "1.6.0rc0", False),
32+
(1, 6, 0, "1.6", True),
33+
(1, 6, 0, "1", False),
34+
(1, 6, 0, "1.6.0+cpu", True),
35+
(1, 6, 1, "1.6.0+cpu", False),
36+
)
37+
38+
39+
class TestPytorchVersionCompare(unittest.TestCase):
40+
@parameterized.expand(TEST_CASES)
41+
def test_compare(self, a, b, p, current, expected=True):
42+
"""Test pytorch_after with a and b"""
43+
self.assertEqual(pytorch_after(a, b, p, current), expected)
44+
45+
46+
if __name__ == "__main__":
47+
unittest.main()

tests/utils.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,7 @@
3636
from monai.data import create_test_image_2d, create_test_image_3d
3737
from monai.networks import convert_to_torchscript
3838
from monai.utils import optional_import
39-
from monai.utils.misc import is_module_ver_at_least
40-
from monai.utils.module import version_leq
39+
from monai.utils.module import pytorch_after, version_leq
4140
from monai.utils.type_conversion import convert_data_type
4241

4342
nib, _ = optional_import("nibabel")
@@ -193,7 +192,7 @@ class SkipIfBeforePyTorchVersion:
193192

194193
def __init__(self, pytorch_version_tuple):
195194
self.min_version = pytorch_version_tuple
196-
self.version_too_old = not is_module_ver_at_least(torch, pytorch_version_tuple)
195+
self.version_too_old = not pytorch_after(*pytorch_version_tuple)
197196

198197
def __call__(self, obj):
199198
return unittest.skipIf(
@@ -207,8 +206,7 @@ class SkipIfAtLeastPyTorchVersion:
207206

208207
def __init__(self, pytorch_version_tuple):
209208
self.max_version = pytorch_version_tuple
210-
test_ver = ".".join(map(str, self.max_version))
211-
self.version_too_new = version_leq(test_ver, torch.__version__)
209+
self.version_too_new = pytorch_after(*pytorch_version_tuple)
212210

213211
def __call__(self, obj):
214212
return unittest.skipIf(

0 commit comments

Comments
 (0)