Skip to content

Commit 5501bba

Browse files
committed
unittest newtonfool-pytorch
1 parent a1f1ff2 commit 5501bba

File tree

1 file changed

+87
-0
lines changed

1 file changed

+87
-0
lines changed

art/attacks/newtonfool_unittest.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,34 @@
66
from keras.models import Sequential
77
from keras.layers import Dense, Flatten, Conv2D, MaxPooling2D
88
import tensorflow as tf
9+
import numpy as np
10+
import torch.nn as nn
11+
import torch.nn.functional as F
12+
import torch.optim as optim
913

1014
from art.attacks.newtonfool import NewtonFool
1115
from art.classifiers.tensorflow import TFClassifier
1216
from art.classifiers.keras import KerasClassifier
17+
from art.classifiers.pytorch import PyTorchClassifier
1318
from art.utils import load_mnist
1419

1520

21+
class Model(nn.Module):
22+
def __init__(self):
23+
super(Model, self).__init__()
24+
self.conv = nn.Conv2d(1, 16, 5)
25+
self.pool = nn.MaxPool2d(2, 2)
26+
self.fc = nn.Linear(2304, 10)
27+
28+
def forward(self, x):
29+
x = self.pool(F.relu(self.conv(x)))
30+
x = x.view(-1, 2304)
31+
logit_output = self.fc(x)
32+
output = F.softmax(logit_output, dim=1)
33+
34+
return (logit_output, output)
35+
36+
1637
class TestNewtonFool(unittest.TestCase):
1738
"""
1839
A unittest class for testing the NewtonFool attack.
@@ -106,8 +127,74 @@ def test_krclassifier(self):
106127
y_pred_adv_max = y_pred_adv[y_pred_bool]
107128
self.assertTrue((y_pred_max >= y_pred_adv_max).all())
108129

130+
def test_ptclassifier(self):
131+
"""
132+
Third test with the PyTorchClassifier.
133+
:return:
134+
"""
135+
# Get MNIST
136+
batch_size, nb_train, nb_test = 100, 1000, 10
137+
(x_train, y_train), (x_test, y_test), _, _ = load_mnist()
138+
x_train, y_train = x_train[:nb_train], np.argmax(y_train[:nb_train], axis=1)
139+
x_test, y_test = x_test[:nb_test], np.argmax(y_test[:nb_test], axis=1)
140+
x_train = np.swapaxes(x_train, 1, 3)
141+
x_test = np.swapaxes(x_test, 1, 3)
142+
143+
# Create simple CNN
144+
# Define the network
145+
model = Model()
146+
147+
# Define a loss function and optimizer
148+
loss_fn = nn.CrossEntropyLoss()
149+
optimizer = optim.Adam(model.parameters(), lr=0.01)
150+
151+
# Get classifier
152+
ptc = PyTorchClassifier((0, 1), model, loss_fn, optimizer, (1, 28, 28), (10,))
153+
ptc.fit(x_train, y_train, batch_size=batch_size, nb_epochs=1)
154+
155+
# Attack
156+
nf = NewtonFool(ptc)
157+
nf.set_params(max_iter=5)
158+
x_test_adv = nf.generate(x_test)
159+
self.assertFalse((x_test == x_test_adv).all())
160+
161+
y_pred = ptc.predict(x_test)
162+
y_pred_adv = ptc.predict(x_test_adv)
163+
y_pred_bool = y_pred.max(axis=1, keepdims=1) == y_pred
164+
y_pred_max = y_pred.max(axis=1)
165+
y_pred_adv_max = y_pred_adv[y_pred_bool]
166+
self.assertTrue((y_pred_max >= y_pred_adv_max).all())
167+
109168

110169
if __name__ == '__main__':
111170
unittest.main()
112171

113172

173+
174+
175+
176+
177+
178+
179+
180+
181+
182+
183+
184+
185+
186+
187+
188+
189+
190+
191+
192+
193+
194+
195+
196+
197+
198+
199+
200+

0 commit comments

Comments
 (0)