Skip to content

Commit 45db2b0

Browse files
committed
unittest universal_pert and pytorch
1 parent 5501bba commit 45db2b0

File tree

2 files changed

+100
-28
lines changed

2 files changed

+100
-28
lines changed

art/attacks/newtonfool_unittest.py

Lines changed: 0 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -170,31 +170,3 @@ def test_ptclassifier(self):
170170
unittest.main()
171171

172172

173-
174-
175-
176-
177-
178-
179-
180-
181-
182-
183-
184-
185-
186-
187-
188-
189-
190-
191-
192-
193-
194-
195-
196-
197-
198-
199-
200-

art/attacks/universal_perturbation_unittest.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,33 @@
77
import keras.backend as k
88
from keras.models import Sequential
99
from keras.layers import Dense, Flatten, Conv2D, MaxPooling2D
10+
import torch.nn as nn
11+
import torch.nn.functional as F
12+
import torch.optim as optim
1013

1114
from art.attacks.universal_perturbation import UniversalPerturbation
1215
from art.classifiers.tensorflow import TFClassifier
1316
from art.classifiers.keras import KerasClassifier
17+
from art.classifiers.pytorch import PyTorchClassifier
1418
from art.utils import load_mnist
1519

1620

21+
class Model(nn.Module):
22+
def __init__(self):
23+
super(Model, self).__init__()
24+
self.conv = nn.Conv2d(1, 16, 5)
25+
self.pool = nn.MaxPool2d(2, 2)
26+
self.fc = nn.Linear(2304, 10)
27+
28+
def forward(self, x):
29+
x = self.pool(F.relu(self.conv(x)))
30+
x = x.view(-1, 2304)
31+
logit_output = self.fc(x)
32+
output = F.softmax(logit_output, dim=1)
33+
34+
return (logit_output, output)
35+
36+
1737
class TestUniversalPerturbation(unittest.TestCase):
1838
"""
1939
A unittest class for testing the UniversalPerturbation attack.
@@ -115,8 +135,88 @@ def test_krclassifier(self):
115135
self.assertFalse((np.argmax(y_test, axis=1) == test_y_pred).all())
116136
self.assertFalse((np.argmax(y_train, axis=1) == train_y_pred).all())
117137

138+
def test_ptclassifier(self):
139+
"""
140+
Third test with the PyTorchClassifier.
141+
:return:
142+
"""
143+
# Get MNIST
144+
batch_size, nb_train, nb_test = 100, 1000, 10
145+
(x_train, y_train), (x_test, y_test), _, _ = load_mnist()
146+
x_train, y_train = x_train[:nb_train], np.argmax(y_train[:nb_train], axis=1)
147+
x_test, y_test = x_test[:nb_test], np.argmax(y_test[:nb_test], axis=1)
148+
x_train = np.swapaxes(x_train, 1, 3)
149+
x_test = np.swapaxes(x_test, 1, 3)
150+
151+
# Create simple CNN
152+
# Define the network
153+
model = Model()
154+
155+
# Define a loss function and optimizer
156+
loss_fn = nn.CrossEntropyLoss()
157+
optimizer = optim.Adam(model.parameters(), lr=0.01)
158+
159+
# Get classifier
160+
ptc = PyTorchClassifier((0, 1), model, loss_fn, optimizer, (1, 28, 28), (10,))
161+
ptc.fit(x_train, y_train, batch_size=batch_size, nb_epochs=1)
162+
163+
# Attack
164+
# TODO Launch with all possible attacks
165+
attack_params = {"attacker": "newtonfool", "attacker_params": {"max_iter": 20}}
166+
up = UniversalPerturbation(ptc)
167+
x_train_adv = up.generate(x_train, **attack_params)
168+
self.assertTrue((up.fooling_rate >= 0.2) or not up.converged)
169+
170+
x_test_adv = x_test + up.v
171+
self.assertFalse((x_test == x_test_adv).all())
172+
173+
train_y_pred = np.argmax(ptc.predict(x_train_adv), axis=1)
174+
test_y_pred = np.argmax(ptc.predict(x_test_adv), axis=1)
175+
self.assertFalse((y_test == test_y_pred).all())
176+
self.assertFalse((y_train == train_y_pred).all())
177+
118178

119179
if __name__ == '__main__':
120180
unittest.main()
121181

122182

183+
184+
185+
186+
187+
188+
189+
190+
191+
192+
193+
194+
195+
196+
197+
198+
199+
200+
201+
202+
203+
204+
205+
206+
207+
208+
209+
210+
211+
212+
213+
214+
215+
216+
217+
218+
219+
220+
221+
222+

0 commit comments

Comments
 (0)