Skip to content

Commit a432049

Browse files
Irina NicolaeIrina Nicolae
authored andcommitted
Merge from dev
2 parents d9fa474 + 2efb8e0 commit a432049

File tree

10 files changed

+116
-41
lines changed

10 files changed

+116
-41
lines changed

art/attacks/universal_perturbation.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ def generate(self, x, **kwargs):
8181
# Instantiate the middle attacker and get the predicted labels
8282
attacker = self._get_attack(self.attacker, self.attacker_params)
8383
pred_y = self.classifier.predict(x, logits=False)
84+
pred_y_max = np.argmax(pred_y, axis=1)
8485

8586
# Start to generate the adversarial examples
8687
nb_iter = 0
@@ -92,14 +93,14 @@ def generate(self, x, **kwargs):
9293
for j, ex in enumerate(x[rnd_idx]):
9394
xi = ex[None, ...]
9495

95-
f_xi = self.classifier.predict(xi + v, logits=False)
96+
f_xi = self.classifier.predict(xi + v, logits=True)
9697
fk_i_hat = np.argmax(f_xi[0])
9798
fk_hat = np.argmax(pred_y[rnd_idx][j])
9899

99100
if fk_i_hat == fk_hat:
100101
# Compute adversarial perturbation
101102
adv_xi = attacker.generate(xi + v)
102-
adv_f_xi = self.classifier.predict(adv_xi, logits=False)
103+
adv_f_xi = self.classifier.predict(adv_xi, logits=True)
103104
adv_fk_i_hat = np.argmax(adv_f_xi[0])
104105

105106
# If the class has changed, update v
@@ -112,10 +113,8 @@ def generate(self, x, **kwargs):
112113

113114
# Compute the error rate
114115
adv_x = x + v
115-
adv_y = self.classifier.predict(adv_x, logits=False)
116-
adv_y_max = np.argmax(adv_y, axis=1)
117-
pred_y_max = np.argmax(pred_y, axis=1)
118-
fooling_rate = np.sum(pred_y_max != adv_y_max) / float(nb_instances)
116+
adv_y = np.argmax(self.classifier.predict(adv_x, logits=False))
117+
fooling_rate = np.sum(pred_y_max != adv_y) / nb_instances
119118

120119
self.fooling_rate = fooling_rate
121120
self.converged = (nb_iter < self.max_iter)
@@ -213,4 +212,3 @@ def _get_class(self, class_name):
213212
class_module = getattr(module, sub_mods[-1])
214213

215214
return class_module
216-

art/attacks/virtual_adversarial.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -56,17 +56,23 @@ def generate(self, x, **kwargs):
5656

5757
for ind, val in enumerate(x_adv):
5858
d = np.random.randn(*dims)
59-
e = np.random.randn(*dims)
59+
6060
for _ in range(self.max_iter):
61-
d = self.finite_diff * self._normalize(d)
62-
e = self.finite_diff * self._normalize(e)
63-
preds_new = self.classifier.predict(np.stack((val + d, val + e)))
64-
65-
# Compute KL divergence between logits
61+
d = self._normalize(d)
62+
preds_new = self.classifier.predict((val + d)[None, ...], logits=False)
63+
6664
from scipy.stats import entropy
6765
kl_div1 = entropy(preds[ind], preds_new[0])
68-
kl_div2 = entropy(preds[ind], preds_new[1])
69-
d = (kl_div1 - kl_div2) / np.abs(d - e)
66+
67+
# TODO remove for loop
68+
d_new = d
69+
for i in range(*dims):
70+
d[i] += self.finite_diff
71+
preds_new = self.classifier.predict((val + d)[None, ...], logits=False)
72+
kl_div2 = entropy(preds[ind], preds_new[0])
73+
d_new[i] = (kl_div2-kl_div1)/self.finite_diff
74+
d[i] -= self.finite_diff
75+
d = d_new
7076

7177
# Apply perturbation and clip
7278
val = np.clip(val + self.eps * self._normalize(d), clip_min, clip_max)
@@ -88,7 +94,6 @@ def _normalize(x):
8894
dims = x.shape
8995

9096
x = x.flatten()
91-
x /= np.max(np.abs(x)) + tol
9297
inverse = (np.sum(x**2) + np.sqrt(tol)) ** -.5
9398
x = x * inverse
9499
x = np.reshape(x, dims)

art/classifiers/classifier.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
import abc
44
import sys
55

6-
# TODO Add tests for defences on classifier
7-
86
# Ensure compatibility with Python 2 and 3 when using ABCMeta
97
if sys.version_info >= (3, 4):
108
ABC = abc.ABC
@@ -25,7 +23,6 @@ def __init__(self, clip_values, defences=None):
2523
:type clip_values: `tuple`
2624
"""
2725
self._clip_values = clip_values
28-
self._parse_defences(defences)
2926

3027
def predict(self, inputs, logits=False):
3128
"""

art/classifiers/keras.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,6 @@ def __init__(self, clip_values, model, use_logits=False, defences=None):
2020
:type model: `keras.models.Sequential`
2121
:param use_logits: True if the output of the model are the logits
2222
:type use_logits: `bool`
23-
:param defences: Defences to be activated with the classifier.
24-
:type defences: `str` or `list(str)`
2523
"""
2624
import keras.backend as k
2725

art/classifiers/tensorflow.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def __init__(self, clip_values, input_ph, logits, output_ph=None, train=None, lo
3838
"""
3939
import tensorflow as tf
4040

41-
super(TFClassifier, self).__init__(clip_values)
41+
super(TFClassifier, self).__init__(clip_values, defences)
4242
self._nb_classes = int(logits.get_shape()[-1])
4343
self._input_shape = tuple(input_ph.get_shape()[1:])
4444
self._input_ph = input_ph
@@ -178,4 +178,3 @@ def loss_gradient(self, inputs, labels):
178178
grds = self._sess.run(self._loss_grads, feed_dict={self._input_ph: inputs, self._output_ph: labels})
179179

180180
return grds
181-

art/defences/preprocessor.py

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -30,45 +30,41 @@ def is_fitted(self):
3030
:return: `True` if the preprocessing model has been fitted (if this applies).
3131
:rtype: `bool`
3232
"""
33-
return self._input_shape
33+
return self._is_fitted
3434

3535
@abc.abstractmethod
3636
def __call__(self, x, y=None):
3737
"""
3838
Perform data preprocessing and return preprocessed data as tuple.
3939
40-
:param x: (np.ndarray) Dataset to be preprocessed
41-
:param y: (np.ndarray) Labels to be preprocessed
40+
:param x: Dataset to be preprocessed.
41+
:type x: `np.ndarray`
42+
:param y: Labels to be preprocessed.
43+
:type y: `np.ndarray`
4244
:return: Preprocessed data
4345
"""
44-
pass
46+
raise NotImplementedError
4547

4648
@abc.abstractmethod
4749
def fit(self, x, y=None, **kwargs):
4850
"""
4951
Fit the parameters of the data preprocessor if it has any.
5052
51-
:param x: (np.ndarray) Training set to fit the preprocessor
52-
:param y: (np.ndarray) Labels for the training set
53-
:param kwargs: (dict) Other parameters
53+
:param x: Training set to fit the preprocessor.
54+
:type x: `np.ndarray`
55+
:param y: Labels for the training set.
56+
:type y: `np.ndarray`
57+
:param kwargs: Other parameters.
58+
:type kwargs: `dict`
5459
:return: None
5560
"""
5661
self._is_fitted = True
5762

58-
def predict(self, x, y=None):
59-
"""
60-
Perform data preprocessing and return preprocessed data as tuple.
61-
62-
:param x: (np.ndarray) Dataset to be preprocessed
63-
:param y: (np.ndarray) Labels to be preprocessed
64-
:return: Preprocessed data
65-
"""
66-
return self.__call__(x, y)
67-
6863
def set_params(self, **kwargs):
6964
"""
7065
Take in a dictionary of parameters and apply checks before saving them as attributes.
71-
:return: True when parsing was successful
66+
67+
:return: `True` when parsing was successful
7268
"""
7369
for key, value in kwargs.items():
7470
if key in self.params:

art/detection/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
"""
2+
Module providing methods for detecting adversarial samples under a common interface.
3+
"""
4+
from art.detection.detector import Detector

art/detection/detector.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
from __future__ import absolute_import, division, print_function, unicode_literals
2+
3+
import abc
4+
import sys
5+
6+
7+
# Ensure compatibility with Python 2 and 3 when using ABCMeta
8+
if sys.version_info >= (3, 4):
9+
ABC = abc.ABC
10+
else:
11+
ABC = abc.ABCMeta(str('ABC'), (), {})
12+
13+
14+
class Detector(ABC):
15+
"""
16+
Base abstract class for all detection methods.
17+
"""
18+
def __init__(self):
19+
"""
20+
Create a detector.
21+
"""
22+
self._is_fitted = False
23+
24+
@property
25+
def is_fitted(self):
26+
"""
27+
Return the state of the detector.
28+
29+
:return: `True` if the detection model has been fitted (if this applies).
30+
:rtype: `bool`
31+
"""
32+
return self._is_fitted
33+
34+
@abc.abstractmethod
35+
def fit(self, x, y=None, **kwargs):
36+
"""
37+
Fit the detector using training data (if this applies).
38+
39+
:param x: Training set to fit the detector.
40+
:type x: `np.ndarray`
41+
:param y: Labels for the training set.
42+
:type y: `np.ndarray`
43+
:param kwargs: Other parameters.
44+
:type kwargs: `dict`
45+
:return: None
46+
"""
47+
self._is_fitted = True
48+
49+
@abc.abstractmethod
50+
def __call__(self, x):
51+
"""
52+
Perform detection of adversarial data and return preprocessed data as tuple.
53+
54+
:param x: Data sample on which to perform detection.
55+
:type x: `np.ndarray`
56+
:return: Per-sample prediction whether data is adversarial or not, where `0` means non-adversarial.
57+
Return variable has the same `batch_size` (first dimension) as `x`.
58+
:rtype: `np.ndarray`
59+
"""
60+
raise NotImplementedError
61+
62+
def set_params(self, **kwargs):
63+
"""
64+
Take in a dictionary of parameters and apply checks before saving them as attributes.
65+
:return: True when parsing was successful
66+
"""
67+
for key, value in kwargs.items():
68+
if key in self.params:
69+
setattr(self, key, value)
70+
return True

docs/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ The following defense methods are also supported:
4949
modules/attacks
5050
modules/classifiers
5151
modules/defences
52+
modules/detection
5253
modules/metrics
5354
modules/utils
5455

docs/modules/detection.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
:mod:`art.detection`
2+
===================
3+
4+
Base Class
5+
----------
6+
.. autoclass:: Detector
7+
:members:

0 commit comments

Comments
 (0)