Skip to content

Commit ba01c04

Browse files
authored
Merge pull request #2502 from Trusted-AI/dev_1.18.2
Update to ART 1.18.2
2 parents 1207d0a + 8738a5a commit ba01c04

File tree

3 files changed

+8
-5
lines changed

3 files changed

+8
-5
lines changed

art/attacks/evasion/adversarial_patch/adversarial_patch_pytorch.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
import logging
2727
import math
28+
from packaging.version import parse
2829
from typing import Any, TYPE_CHECKING
2930

3031
import numpy as np
@@ -121,8 +122,8 @@ def __init__(
121122
import torch
122123
import torchvision
123124

124-
torch_version = list(map(int, torch.__version__.lower().split("+", maxsplit=1)[0].split(".")))
125-
torchvision_version = list(map(int, torchvision.__version__.lower().split("+", maxsplit=1)[0].split(".")))
125+
torch_version = list(parse(torch.__version__.lower()).release)
126+
torchvision_version = list(parse(torchvision.__version__.lower()).release)
126127
assert (
127128
torch_version[0] >= 1 and torch_version[1] >= 7 or (torch_version[0] >= 2)
128129
), "AdversarialPatchPyTorch requires torch>=1.7.0"

art/attacks/evasion/pixel_threshold.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727

2828
import logging
2929
from itertools import product
30+
from packaging.version import parse
3031
from typing import TYPE_CHECKING
3132

3233
import numpy as np
@@ -42,7 +43,7 @@
4243
import scipy
4344
from scipy._lib._util import check_random_state
4445

45-
scipy_version = list(map(int, scipy.__version__.lower().split(".")))
46+
scipy_version = list(parse(scipy.__version__.lower()).release)
4647
if scipy_version[1] >= 8:
4748
from scipy.optimize._optimize import _status_message
4849
else:

art/estimators/object_detection/pytorch_object_detector.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from __future__ import annotations
2222

2323
import logging
24+
from packaging.version import parse
2425
from typing import Any, TYPE_CHECKING
2526

2627
import numpy as np
@@ -96,8 +97,8 @@ def __init__(
9697
import torch
9798
import torchvision
9899

99-
torch_version = list(map(int, torch.__version__.lower().split("+", maxsplit=1)[0].split(".")))
100-
torchvision_version = list(map(int, torchvision.__version__.lower().split("+", maxsplit=1)[0].split(".")))
100+
torch_version = list(parse(torch.__version__.lower()).release)
101+
torchvision_version = list(parse(torchvision.__version__.lower()).release)
101102
assert not (torch_version[0] == 1 and (torch_version[1] == 8 or torch_version[1] == 9)), (
102103
"PyTorchObjectDetector does not support torch==1.8 and torch==1.9 because of "
103104
"https://github.com/pytorch/vision/issues/4153. Support will return for torch==1.10."

0 commit comments

Comments
 (0)