Skip to content

Commit 5a213aa

Browse files
authored
Merge pull request #1589 from Trusted-AI/development_issue_1568
Check scipy version for PixelThreshold attack
2 parents d422b7f + 4dd03fb commit 5a213aa

File tree

2 files changed

+18
-12
lines changed

2 files changed

+18
-12
lines changed

art/attacks/evasion/pixel_threshold.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
| One Pixel Attack Paper link: https://arxiv.org/ans/1710.08864
2323
| Pixel and Threshold Attack Paper link: https://arxiv.org/abs/1906.06026
2424
"""
25-
# pylint: disable=C0302
25+
# pylint: disable=C0302,C0413
2626
from __future__ import absolute_import, division, print_function, unicode_literals
2727

2828
import logging
@@ -39,16 +39,22 @@
3939
# Otherwise may use Tensorflow's implementation of DE.
4040

4141
from six import string_types
42+
import scipy
4243
from scipy._lib._util import check_random_state
43-
from scipy.optimize.optimize import _status_message
44-
from scipy.optimize import OptimizeResult, minimize
45-
from tqdm.auto import tqdm
46-
47-
from art.config import ART_NUMPY_DTYPE
48-
from art.attacks.attack import EvasionAttack
49-
from art.estimators.estimator import BaseEstimator, NeuralNetworkMixin
50-
from art.estimators.classification.classifier import ClassifierMixin
51-
from art.utils import check_and_transform_label_format
44+
45+
scipy_version = list(map(int, scipy.__version__.lower().split(".")))
46+
if scipy_version[1] >= 8:
47+
from scipy.optimize._optimize import _status_message # pylint: disable=E0611
48+
else:
49+
from scipy.optimize.optimize import _status_message # pylint: disable=E0611
50+
from scipy.optimize import OptimizeResult, minimize # noqa
51+
from tqdm.auto import tqdm # noqa
52+
53+
from art.config import ART_NUMPY_DTYPE # noqa
54+
from art.attacks.attack import EvasionAttack # noqa
55+
from art.estimators.estimator import BaseEstimator, NeuralNetworkMixin # noqa
56+
from art.estimators.classification.classifier import ClassifierMixin # noqa
57+
from art.utils import check_and_transform_label_format # noqa
5258

5359
if TYPE_CHECKING:
5460
from art.utils import CLASSIFIER_NEURALNETWORK_TYPE

tests/defences/trainer/test_adversarial_trainer_FBF.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def _get_adv_trainer():
3434
trainer = None
3535
if framework == "pytorch":
3636
classifier, _ = image_dl_estimator()
37-
trainer = AdversarialTrainerFBFPyTorch(classifier, eps=0.1)
37+
trainer = AdversarialTrainerFBFPyTorch(classifier, eps=0.05)
3838
if framework == "scikitlearn":
3939
trainer = None
4040

@@ -75,7 +75,7 @@ def test_adversarial_trainer_fbf_pytorch_fit_and_predict(get_adv_trainer, fix_ge
7575
)
7676

7777
assert accuracy == 0.32
78-
assert accuracy_new == 0.58
78+
assert accuracy_new == 0.63
7979

8080
trainer.fit(x_train_mnist, y_train_mnist, nb_epochs=20, validation_data=(x_train_mnist, y_train_mnist))
8181

0 commit comments

Comments
 (0)