|
7 | 7 | import keras.backend as k |
8 | 8 | from keras.models import Sequential |
9 | 9 | from keras.layers import Dense, Flatten, Conv2D, MaxPooling2D |
| 10 | +import torch.nn as nn |
| 11 | +import torch.nn.functional as F |
| 12 | +import torch.optim as optim |
10 | 13 |
|
11 | 14 | from art.attacks.carlini import CarliniL2Method |
12 | 15 | from art.classifiers.tensorflow import TFClassifier |
13 | 16 | from art.classifiers.keras import KerasClassifier |
| 17 | +from art.classifiers.pytorch import PyTorchClassifier |
14 | 18 | from art.utils import load_mnist, random_targets |
15 | 19 |
|
16 | 20 |
|
| 21 | +class Model(nn.Module): |
| 22 | + def __init__(self): |
| 23 | + super(Model, self).__init__() |
| 24 | + self.conv = nn.Conv2d(1, 16, 5) |
| 25 | + self.pool = nn.MaxPool2d(2, 2) |
| 26 | + self.fc = nn.Linear(2304, 10) |
| 27 | + |
| 28 | + def forward(self, x): |
| 29 | + x = self.pool(F.relu(self.conv(x))) |
| 30 | + x = x.view(-1, 2304) |
| 31 | + logit_output = self.fc(x) |
| 32 | + output = F.softmax(logit_output, dim=1) |
| 33 | + |
| 34 | + return (logit_output, output) |
| 35 | + |
| 36 | + |
17 | 37 | class TestCarliniL2(unittest.TestCase): |
18 | 38 | """ |
19 | 39 | A unittest class for testing the Carlini2 attack. |
@@ -145,6 +165,107 @@ def test_krclassifier(self): |
145 | 165 | y_pred_adv = np.argmax(krc.predict(x_test_adv), axis=1) |
146 | 166 | self.assertTrue((y_pred != y_pred_adv).any()) |
147 | 167 |
|
| 168 | + def test_ptclassifier(self): |
| 169 | + """ |
| 170 | + Third test with the PyTorchClassifier. |
| 171 | + :return: |
| 172 | + """ |
| 173 | + # Get MNIST |
| 174 | + batch_size, nb_train, nb_test = 100, 1000, 10 |
| 175 | + (x_train, y_train), (x_test, y_test), _, _ = load_mnist() |
| 176 | + x_train, y_train = x_train[:nb_train], np.argmax(y_train[:nb_train], axis=1) |
| 177 | + x_test, y_test = x_test[:nb_test], y_test[:nb_test] |
| 178 | + x_train = np.swapaxes(x_train, 1, 3) |
| 179 | + x_test = np.swapaxes(x_test, 1, 3) |
| 180 | + |
| 181 | + # Create simple CNN |
| 182 | + # Define the network |
| 183 | + model = Model() |
| 184 | + |
| 185 | + # Define a loss function and optimizer |
| 186 | + loss_fn = nn.CrossEntropyLoss() |
| 187 | + optimizer = optim.Adam(model.parameters(), lr=0.01) |
| 188 | + |
| 189 | + # Get classifier |
| 190 | + ptc = PyTorchClassifier((0, 1), model, loss_fn, optimizer, (1, 28, 28), (10,)) |
| 191 | + ptc.fit(x_train, y_train, batch_size=batch_size, nb_epochs=1) |
| 192 | + |
| 193 | + # First attack |
| 194 | + cl2m = CarliniL2Method(classifier=ptc, targeted=True, max_iter=100, binary_search_steps=10, |
| 195 | + learning_rate=2e-2, initial_const=3, decay=1e-2) |
| 196 | + params = {'y': random_targets(y_test, ptc.nb_classes)} |
| 197 | + x_test_adv = cl2m.generate(x_test, **params) |
| 198 | + self.assertFalse((x_test == x_test_adv).all()) |
| 199 | + target = np.argmax(params['y'], axis=1) |
| 200 | + y_pred_adv = np.argmax(ptc.predict(x_test_adv), axis=1) |
| 201 | + self.assertTrue((target == y_pred_adv).any()) |
| 202 | + |
| 203 | + # Second attack |
| 204 | + cl2m = CarliniL2Method(classifier=ptc, targeted=False, max_iter=100, binary_search_steps=10, |
| 205 | + learning_rate=2e-2, initial_const=3, decay=1e-2) |
| 206 | + params = {'y': random_targets(y_test, ptc.nb_classes)} |
| 207 | + x_test_adv = cl2m.generate(x_test, **params) |
| 208 | + self.assertFalse((x_test == x_test_adv).all()) |
| 209 | + target = np.argmax(params['y'], axis=1) |
| 210 | + y_pred_adv = np.argmax(ptc.predict(x_test_adv), axis=1) |
| 211 | + self.assertTrue((target != y_pred_adv).all()) |
| 212 | + |
| 213 | + # Third attack |
| 214 | + cl2m = CarliniL2Method(classifier=ptc, targeted=False, max_iter=100, binary_search_steps=10, |
| 215 | + learning_rate=2e-2, initial_const=3, decay=1e-2) |
| 216 | + params = {} |
| 217 | + x_test_adv = cl2m.generate(x_test, **params) |
| 218 | + self.assertFalse((x_test == x_test_adv).all()) |
| 219 | + y_pred = np.argmax(ptc.predict(x_test), axis=1) |
| 220 | + y_pred_adv = np.argmax(ptc.predict(x_test_adv), axis=1) |
| 221 | + self.assertTrue((y_pred != y_pred_adv).any()) |
| 222 | + |
148 | 223 |
|
149 | 224 | if __name__ == '__main__': |
150 | 225 | unittest.main() |
| 226 | + |
| 227 | + |
| 228 | + |
| 229 | + |
| 230 | + |
| 231 | + |
| 232 | + |
| 233 | + |
| 234 | + |
| 235 | + |
| 236 | + |
| 237 | + |
| 238 | + |
| 239 | + |
| 240 | + |
| 241 | + |
| 242 | + |
| 243 | + |
| 244 | + |
| 245 | + |
| 246 | + |
| 247 | + |
| 248 | + |
| 249 | + |
| 250 | + |
| 251 | + |
| 252 | + |
| 253 | + |
| 254 | + |
| 255 | + |
| 256 | + |
| 257 | + |
| 258 | + |
| 259 | + |
| 260 | + |
| 261 | + |
| 262 | + |
| 263 | + |
| 264 | + |
| 265 | + |
| 266 | + |
| 267 | + |
| 268 | + |
| 269 | + |
| 270 | + |
| 271 | + |
0 commit comments