Commit cc5a28b

Merge branch 'dev_1.15.1' into development_issue_2234
2 parents 0e27b3d + 509e223 commit cc5a28b

13 files changed: +1506 −594 lines changed

art/attacks/evasion/auto_conjugate_gradient.py

Lines changed: 4 additions & 2 deletions
@@ -463,7 +463,9 @@ def generate(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> n

     # self.eta = np.full((self.batch_size, 1, 1, 1), 2 * self.eps_step).astype(ART_NUMPY_DTYPE)
     _batch_size = x_k.shape[0]
-    eta = np.full((_batch_size, 1, 1, 1), self.eps_step).astype(ART_NUMPY_DTYPE)
+    eta = np.full((_batch_size,) + (1,) * len(self.estimator.input_shape), self.eps_step).astype(
+        ART_NUMPY_DTYPE
+    )
     self.count_condition_1 = np.zeros(shape=(_batch_size,))
     gradk_1 = np.zeros_like(x_k)
     cgradk_1 = np.zeros_like(x_k)
@@ -650,4 +652,4 @@ def get_beta(gradk, gradk_1, cgradk_1):
     betak = -(_gradk * delta_gradk).sum(axis=1) / (
         (_cgradk_1 * delta_gradk).sum(axis=1) + np.finfo(ART_NUMPY_DTYPE).eps
     )
-    return betak.reshape((_batch_size, 1, 1, 1))
+    return betak.reshape((_batch_size,) + (1,) * (len(gradk.shape) - 1))

art/attacks/evasion/auto_projected_gradient_descent.py

Lines changed: 3 additions & 1 deletion
@@ -458,7 +458,9 @@ def generate(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> n

     # modification for image-wise stepsize update
     _batch_size = x_k.shape[0]
-    eta = np.full((_batch_size, 1, 1, 1), self.eps_step).astype(ART_NUMPY_DTYPE)
+    eta = np.full((_batch_size,) + (1,) * len(self.estimator.input_shape), self.eps_step).astype(
+        ART_NUMPY_DTYPE
+    )
     self.count_condition_1 = np.zeros(shape=(_batch_size,))

     for k_iter in trange(self.max_iter, desc="AutoPGD - iteration", leave=False, disable=not self.verbose):

art/attacks/poisoning/perturbations/image_perturbations.py

Lines changed: 2 additions & 1 deletion
@@ -21,7 +21,6 @@
 from typing import Optional, Tuple

 import numpy as np
-from PIL import Image


 def add_single_bd(x: np.ndarray, distance: int = 2, pixel_value: int = 1) -> np.ndarray:
@@ -112,6 +111,8 @@ def insert_image(
     :param blend: The blending factor
     :return: Backdoored image.
     """
+    from PIL import Image
+
     n_dim = len(x.shape)
     if n_dim == 4:
         return np.array(
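Moving the PIL import from module level into insert_image makes Pillow effectively optional: the perturbations module still imports when Pillow is absent, and the dependency is only required on the code path that uses it. A minimal sketch of the deferred-import pattern (function name and resizing step are illustrative, not ART's API):

import numpy as np

def resize_patch_sketch(x: np.ndarray, size: int) -> np.ndarray:
    # deferred import: only evaluated when this function is actually called
    from PIL import Image
    return np.array(Image.fromarray(x.astype(np.uint8)).resize((size, size)))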

art/estimators/object_detection/pytorch_object_detector.py

Lines changed: 4 additions & 1 deletion
@@ -219,7 +219,10 @@ def _preprocess_and_convert_inputs(

         # Set gradients
         if not no_grad:
-            x_tensor.requires_grad = True
+            if x_tensor.is_leaf:
+                x_tensor.requires_grad = True
+            else:
+                x_tensor.retain_grad()

         # Apply framework-specific preprocessing
         x_preprocessed, y_preprocessed = self._apply_preprocessing(x=x_tensor, y=y_tensor, fit=fit, no_grad=no_grad)
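The branch is needed because requires_grad can only be assigned on leaf tensors; if x_tensor was produced by an earlier operation, the assignment raises a RuntimeError, and retain_grad() is the way to keep gradients on a non-leaf tensor. A standalone PyTorch sketch of the distinction (tensor shapes are illustrative):

import torch

leaf = torch.zeros(2, 3)                 # created directly -> leaf tensor
leaf.requires_grad = True                # allowed on leaf tensors

base = torch.ones(2, 3, requires_grad=True)
non_leaf = base * 2.0                    # result of an op -> non-leaf tensor
# non_leaf.requires_grad = True          # would raise a RuntimeError
non_leaf.retain_grad()                   # ask autograd to keep .grad here
non_leaf.sum().backward()
print(non_leaf.grad)                     # gradient is now available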

art/visualization.py

Lines changed: 2 additions & 1 deletion
@@ -25,7 +25,6 @@
 from typing import List, Optional, TYPE_CHECKING

 import numpy as np
-from PIL import Image

 from art import config

@@ -97,6 +96,8 @@ def save_image(image_array: np.ndarray, f_name: str) -> None:
     :param image_array: Image to be saved.
     :param f_name: File name containing extension e.g., my_img.jpg, my_img.png, my_images/my_img.png.
     """
+    from PIL import Image
+
     file_name = os.path.join(config.ART_DATA_PATH, f_name)
     folder = os.path.split(file_name)[0]
     if not os.path.exists(folder):

examples/adversarial_training_FBF.py

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@
 import torchvision.transforms as transforms
 from torch.utils.data import Dataset, DataLoader

-from art.classifiers import PyTorchClassifier
+from art.estimators.classification import PyTorchClassifier
 from art.data_generators import PyTorchDataGenerator
 from art.defences.trainer import AdversarialTrainerFBFPyTorch
 from art.utils import load_cifar10

examples/adversarial_training_data_augmentation.py

Lines changed: 4 additions & 0 deletions
@@ -1,6 +1,10 @@
 """
 This is an example of how to use ART and Keras to perform adversarial training using data generators for CIFAR10
 """
+import tensorflow as tf
+
+tf.compat.v1.disable_eager_execution()
+
 import keras
 import numpy as np
 from keras.layers import Conv2D, Dense, Flatten, MaxPooling2D, Input, BatchNormalization
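The example's Keras-based workflow relies on TF1-style graph execution, so eager mode has to be switched off before the Keras model is constructed; placing the call at the very top of the script guarantees that ordering. A minimal standalone check of the effect (not part of the example script itself):

import tensorflow as tf

tf.compat.v1.disable_eager_execution()  # must run before any Keras model is built
print(tf.executing_eagerly())           # -> False for the rest of the process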

examples/get_started_lightgbm.py

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@

 # Step 2: Create the model

-params = {"objective": "multiclass", "metric": "multi_logloss", "num_class": 10}
+params = {"objective": "multiclass", "metric": "multi_logloss", "num_class": 10, "force_col_wise": True}
 train_set = lgb.Dataset(x_train, label=np.argmax(y_train, axis=1))
 test_set = lgb.Dataset(x_test, label=np.argmax(y_test, axis=1))
 model = lgb.train(params=params, train_set=train_set, num_boost_round=100, valid_sets=[test_set])
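Setting force_col_wise=True makes LightGBM build column-wise histograms outright instead of benchmarking row-wise versus column-wise at startup, which removes the overhead-testing log message and keeps runs consistent. A standalone sketch with synthetic data (dataset shape and round count are illustrative):

import lightgbm as lgb
import numpy as np

x = np.random.rand(200, 5)
y = np.random.randint(0, 3, size=200)
params = {"objective": "multiclass", "metric": "multi_logloss", "num_class": 3, "force_col_wise": True}
model = lgb.train(params=params, train_set=lgb.Dataset(x, label=y), num_boost_round=5)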

examples/get_started_xgboost.py

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@

 # Step 2: Create the model

-params = {"objective": "multi:softprob", "metric": "accuracy", "num_class": 10}
+params = {"objective": "multi:softprob", "eval_metric": ["mlogloss", "merror"], "num_class": 10}
 dtrain = xgb.DMatrix(x_train, label=np.argmax(y_train, axis=1))
 dtest = xgb.DMatrix(x_test, label=np.argmax(y_test, axis=1))
 evals = [(dtest, "test"), (dtrain, "train")]
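XGBoost does not recognise a "metric" key (that spelling belongs to LightGBM); the parameter is named eval_metric, and it accepts a list so both multiclass log-loss and error rate are reported on the evals sets during training. A standalone sketch with synthetic data (shapes and round count are illustrative):

import numpy as np
import xgboost as xgb

x = np.random.rand(200, 4)
y = np.random.randint(0, 3, size=200)
dtrain = xgb.DMatrix(x, label=y)
params = {"objective": "multi:softprob", "eval_metric": ["mlogloss", "merror"], "num_class": 3}
model = xgb.train(params=params, dtrain=dtrain, num_boost_round=5, evals=[(dtrain, "train")])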

examples/mnist_cnn_fgsm.py

Lines changed: 16 additions & 8 deletions
@@ -2,6 +2,10 @@
 """Trains a convolutional neural network on the MNIST dataset, then attacks it with the FGSM attack."""
 from __future__ import absolute_import, division, print_function, unicode_literals

+import tensorflow as tf
+
+tf.compat.v1.disable_eager_execution()
+
 from keras.models import Sequential
 from keras.layers import Dense, Flatten, Conv2D, MaxPooling2D, Dropout
 import numpy as np
@@ -35,12 +39,16 @@
 acc = np.sum(preds == np.argmax(y_test, axis=1)) / y_test.shape[0]
 print("\nTest accuracy: %.2f%%" % (acc * 100))

-# Craft adversarial samples with FGSM
-epsilon = 0.1  # Maximum perturbation
-adv_crafter = FastGradientMethod(classifier, eps=epsilon)
-x_test_adv = adv_crafter.generate(x=x_test)
+# Define epsilon values
+epsilon_values = [0.01, 0.1, 0.15, 0.2, 0.25, 0.3]

-# Evaluate the classifier on the adversarial examples
-preds = np.argmax(classifier.predict(x_test_adv), axis=1)
-acc = np.sum(preds == np.argmax(y_test, axis=1)) / y_test.shape[0]
-print("\nTest accuracy on adversarial sample: %.2f%%" % (acc * 100))
+# Iterate over epsilon values
+for epsilon in epsilon_values:
+    # Craft adversarial samples with FGSM
+    adv_crafter = FastGradientMethod(classifier, eps=epsilon)
+    x_test_adv = adv_crafter.generate(x=x_test, y=y_test)
+
+    # Evaluate the classifier on the adversarial examples
+    preds = np.argmax(classifier.predict(x_test_adv), axis=1)
+    acc = np.sum(preds == np.argmax(y_test, axis=1)) / y_test.shape[0]
+    print("Test accuracy on adversarial sample (epsilon = %.2f): %.2f%%" % (epsilon, acc * 100))
