|
7 | 7 | import keras.backend as k |
8 | 8 | from keras.models import Sequential |
9 | 9 | from keras.layers import Dense, Flatten, Conv2D, MaxPooling2D |
| 10 | +import torch.nn as nn |
| 11 | +import torch.nn.functional as F |
| 12 | +import torch.optim as optim |
10 | 13 |
|
11 | 14 | from art.attacks.universal_perturbation import UniversalPerturbation |
12 | 15 | from art.classifiers.tensorflow import TFClassifier |
13 | 16 | from art.classifiers.keras import KerasClassifier |
| 17 | +from art.classifiers.pytorch import PyTorchClassifier |
14 | 18 | from art.utils import load_mnist |
15 | 19 |
|
16 | 20 |
|
| 21 | +class Model(nn.Module): |
| 22 | + def __init__(self): |
| 23 | + super(Model, self).__init__() |
| 24 | + self.conv = nn.Conv2d(1, 16, 5) |
| 25 | + self.pool = nn.MaxPool2d(2, 2) |
| 26 | + self.fc = nn.Linear(2304, 10) |
| 27 | + |
| 28 | + def forward(self, x): |
| 29 | + x = self.pool(F.relu(self.conv(x))) |
| 30 | + x = x.view(-1, 2304) |
| 31 | + logit_output = self.fc(x) |
| 32 | + output = F.softmax(logit_output, dim=1) |
| 33 | + |
| 34 | + return (logit_output, output) |
| 35 | + |
| 36 | + |
17 | 37 | class TestUniversalPerturbation(unittest.TestCase): |
18 | 38 | """ |
19 | 39 | A unittest class for testing the UniversalPerturbation attack. |
@@ -115,8 +135,88 @@ def test_krclassifier(self): |
115 | 135 | self.assertFalse((np.argmax(y_test, axis=1) == test_y_pred).all()) |
116 | 136 | self.assertFalse((np.argmax(y_train, axis=1) == train_y_pred).all()) |
117 | 137 |
|
| 138 | + def test_ptclassifier(self): |
| 139 | + """ |
| 140 | + Third test with the PyTorchClassifier. |
| 141 | + :return: |
| 142 | + """ |
| 143 | + # Get MNIST |
| 144 | + batch_size, nb_train, nb_test = 100, 1000, 10 |
| 145 | + (x_train, y_train), (x_test, y_test), _, _ = load_mnist() |
| 146 | + x_train, y_train = x_train[:nb_train], np.argmax(y_train[:nb_train], axis=1) |
| 147 | + x_test, y_test = x_test[:nb_test], np.argmax(y_test[:nb_test], axis=1) |
| 148 | + x_train = np.swapaxes(x_train, 1, 3) |
| 149 | + x_test = np.swapaxes(x_test, 1, 3) |
| 150 | + |
| 151 | + # Create simple CNN |
| 152 | + # Define the network |
| 153 | + model = Model() |
| 154 | + |
| 155 | + # Define a loss function and optimizer |
| 156 | + loss_fn = nn.CrossEntropyLoss() |
| 157 | + optimizer = optim.Adam(model.parameters(), lr=0.01) |
| 158 | + |
| 159 | + # Get classifier |
| 160 | + ptc = PyTorchClassifier((0, 1), model, loss_fn, optimizer, (1, 28, 28), (10,)) |
| 161 | + ptc.fit(x_train, y_train, batch_size=batch_size, nb_epochs=1) |
| 162 | + |
| 163 | + # Attack |
| 164 | + # TODO Launch with all possible attacks |
| 165 | + attack_params = {"attacker": "newtonfool", "attacker_params": {"max_iter": 20}} |
| 166 | + up = UniversalPerturbation(ptc) |
| 167 | + x_train_adv = up.generate(x_train, **attack_params) |
| 168 | + self.assertTrue((up.fooling_rate >= 0.2) or not up.converged) |
| 169 | + |
| 170 | + x_test_adv = x_test + up.v |
| 171 | + self.assertFalse((x_test == x_test_adv).all()) |
| 172 | + |
| 173 | + train_y_pred = np.argmax(ptc.predict(x_train_adv), axis=1) |
| 174 | + test_y_pred = np.argmax(ptc.predict(x_test_adv), axis=1) |
| 175 | + self.assertFalse((y_test == test_y_pred).all()) |
| 176 | + self.assertFalse((y_train == train_y_pred).all()) |
| 177 | + |
118 | 178 |
|
119 | 179 | if __name__ == '__main__': |
120 | 180 | unittest.main() |
121 | 181 |
|
122 | 182 |
|
| 183 | + |
| 184 | + |
| 185 | + |
| 186 | + |
| 187 | + |
| 188 | + |
| 189 | + |
| 190 | + |
| 191 | + |
| 192 | + |
| 193 | + |
| 194 | + |
| 195 | + |
| 196 | + |
| 197 | + |
| 198 | + |
| 199 | + |
| 200 | + |
| 201 | + |
| 202 | + |
| 203 | + |
| 204 | + |
| 205 | + |
| 206 | + |
| 207 | + |
| 208 | + |
| 209 | + |
| 210 | + |
| 211 | + |
| 212 | + |
| 213 | + |
| 214 | + |
| 215 | + |
| 216 | + |
| 217 | + |
| 218 | + |
| 219 | + |
| 220 | + |
| 221 | + |
| 222 | + |
0 commit comments