Skip to content

Commit 6765ffb

Browse files
committed
unittest carlini-pytorch
1 parent 45db2b0 commit 6765ffb

File tree

2 files changed

+121
-38
lines changed

2 files changed

+121
-38
lines changed

art/attacks/carlini_unittest.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,33 @@
77
import keras.backend as k
88
from keras.models import Sequential
99
from keras.layers import Dense, Flatten, Conv2D, MaxPooling2D
10+
import torch.nn as nn
11+
import torch.nn.functional as F
12+
import torch.optim as optim
1013

1114
from art.attacks.carlini import CarliniL2Method
1215
from art.classifiers.tensorflow import TFClassifier
1316
from art.classifiers.keras import KerasClassifier
17+
from art.classifiers.pytorch import PyTorchClassifier
1418
from art.utils import load_mnist, random_targets
1519

1620

21+
class Model(nn.Module):
22+
def __init__(self):
23+
super(Model, self).__init__()
24+
self.conv = nn.Conv2d(1, 16, 5)
25+
self.pool = nn.MaxPool2d(2, 2)
26+
self.fc = nn.Linear(2304, 10)
27+
28+
def forward(self, x):
29+
x = self.pool(F.relu(self.conv(x)))
30+
x = x.view(-1, 2304)
31+
logit_output = self.fc(x)
32+
output = F.softmax(logit_output, dim=1)
33+
34+
return (logit_output, output)
35+
36+
1737
class TestCarliniL2(unittest.TestCase):
1838
"""
1939
A unittest class for testing the Carlini2 attack.
@@ -145,6 +165,107 @@ def test_krclassifier(self):
145165
y_pred_adv = np.argmax(krc.predict(x_test_adv), axis=1)
146166
self.assertTrue((y_pred != y_pred_adv).any())
147167

168+
def test_ptclassifier(self):
169+
"""
170+
Third test with the PyTorchClassifier.
171+
:return:
172+
"""
173+
# Get MNIST
174+
batch_size, nb_train, nb_test = 100, 1000, 10
175+
(x_train, y_train), (x_test, y_test), _, _ = load_mnist()
176+
x_train, y_train = x_train[:nb_train], np.argmax(y_train[:nb_train], axis=1)
177+
x_test, y_test = x_test[:nb_test], y_test[:nb_test]
178+
x_train = np.swapaxes(x_train, 1, 3)
179+
x_test = np.swapaxes(x_test, 1, 3)
180+
181+
# Create simple CNN
182+
# Define the network
183+
model = Model()
184+
185+
# Define a loss function and optimizer
186+
loss_fn = nn.CrossEntropyLoss()
187+
optimizer = optim.Adam(model.parameters(), lr=0.01)
188+
189+
# Get classifier
190+
ptc = PyTorchClassifier((0, 1), model, loss_fn, optimizer, (1, 28, 28), (10,))
191+
ptc.fit(x_train, y_train, batch_size=batch_size, nb_epochs=1)
192+
193+
# First attack
194+
cl2m = CarliniL2Method(classifier=ptc, targeted=True, max_iter=100, binary_search_steps=10,
195+
learning_rate=2e-2, initial_const=3, decay=1e-2)
196+
params = {'y': random_targets(y_test, ptc.nb_classes)}
197+
x_test_adv = cl2m.generate(x_test, **params)
198+
self.assertFalse((x_test == x_test_adv).all())
199+
target = np.argmax(params['y'], axis=1)
200+
y_pred_adv = np.argmax(ptc.predict(x_test_adv), axis=1)
201+
self.assertTrue((target == y_pred_adv).any())
202+
203+
# Second attack
204+
cl2m = CarliniL2Method(classifier=ptc, targeted=False, max_iter=100, binary_search_steps=10,
205+
learning_rate=2e-2, initial_const=3, decay=1e-2)
206+
params = {'y': random_targets(y_test, ptc.nb_classes)}
207+
x_test_adv = cl2m.generate(x_test, **params)
208+
self.assertFalse((x_test == x_test_adv).all())
209+
target = np.argmax(params['y'], axis=1)
210+
y_pred_adv = np.argmax(ptc.predict(x_test_adv), axis=1)
211+
self.assertTrue((target != y_pred_adv).all())
212+
213+
# Third attack
214+
cl2m = CarliniL2Method(classifier=ptc, targeted=False, max_iter=100, binary_search_steps=10,
215+
learning_rate=2e-2, initial_const=3, decay=1e-2)
216+
params = {}
217+
x_test_adv = cl2m.generate(x_test, **params)
218+
self.assertFalse((x_test == x_test_adv).all())
219+
y_pred = np.argmax(ptc.predict(x_test), axis=1)
220+
y_pred_adv = np.argmax(ptc.predict(x_test_adv), axis=1)
221+
self.assertTrue((y_pred != y_pred_adv).any())
222+
148223

149224
if __name__ == '__main__':
150225
unittest.main()
226+
227+
228+
229+
230+
231+
232+
233+
234+
235+
236+
237+
238+
239+
240+
241+
242+
243+
244+
245+
246+
247+
248+
249+
250+
251+
252+
253+
254+
255+
256+
257+
258+
259+
260+
261+
262+
263+
264+
265+
266+
267+
268+
269+
270+
271+

art/attacks/universal_perturbation_unittest.py

Lines changed: 0 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -182,41 +182,3 @@ def test_ptclassifier(self):
182182

183183

184184

185-
186-
187-
188-
189-
190-
191-
192-
193-
194-
195-
196-
197-
198-
199-
200-
201-
202-
203-
204-
205-
206-
207-
208-
209-
210-
211-
212-
213-
214-
215-
216-
217-
218-
219-
220-
221-
222-

0 commit comments

Comments
 (0)