2424import keras .backend as k
2525from keras .models import Sequential
2626from keras .layers import Dense , Flatten , Conv2D , MaxPooling2D
27+ import torch .nn as nn
28+ import torch .nn .functional as F
29+ import torch .optim as optim
2730
2831from art .attacks .carlini import CarliniL2Method
2932from art .classifiers .tensorflow import TFClassifier
3033from art .classifiers .keras import KerasClassifier
34+ from art .classifiers .pytorch import PyTorchClassifier
3135from art .utils import load_mnist , random_targets
3236
3337
38+ class Model (nn .Module ):
39+ def __init__ (self ):
40+ super (Model , self ).__init__ ()
41+ self .conv = nn .Conv2d (1 , 16 , 5 )
42+ self .pool = nn .MaxPool2d (2 , 2 )
43+ self .fc = nn .Linear (2304 , 10 )
44+
45+ def forward (self , x ):
46+ x = self .pool (F .relu (self .conv (x )))
47+ x = x .view (- 1 , 2304 )
48+ logit_output = self .fc (x )
49+ output = F .softmax (logit_output , dim = 1 )
50+
51+ return logit_output , output
52+
53+
3454class TestCarliniL2 (unittest .TestCase ):
3555 """
3656 A unittest class for testing the Carlini2 attack.
3757 """
58+ def test_failure_attack (self ):
59+ """
60+ Test the corner case when attack is failed.
61+ :return:
62+ """
63+ # Build a TFClassifier
64+ # Define input and output placeholders
65+ self ._input_ph = tf .placeholder (tf .float32 , shape = [None , 28 , 28 , 1 ])
66+ self ._output_ph = tf .placeholder (tf .int32 , shape = [None , 10 ])
67+
68+ # Define the tensorflow graph
69+ conv = tf .layers .conv2d (self ._input_ph , 4 , 5 , activation = tf .nn .relu )
70+ conv = tf .layers .max_pooling2d (conv , 2 , 2 )
71+ fc = tf .contrib .layers .flatten (conv )
72+
73+ # Logits layer
74+ self ._logits = tf .layers .dense (fc , 10 )
75+
76+ # Train operator
77+ self ._loss = tf .reduce_mean (tf .losses .softmax_cross_entropy (logits = self ._logits , onehot_labels = self ._output_ph ))
78+ optimizer = tf .train .AdamOptimizer (learning_rate = 0.01 )
79+ self ._train = optimizer .minimize (self ._loss )
80+
81+ # Tensorflow session and initialization
82+ self ._sess = tf .Session ()
83+ self ._sess .run (tf .global_variables_initializer ())
84+
85+ # Get MNIST
86+ batch_size , nb_train , nb_test = 100 , 1000 , 10
87+ (x_train , y_train ), (x_test , y_test ), _ , _ = load_mnist ()
88+ x_train , y_train = x_train [:nb_train ], y_train [:nb_train ]
89+ x_test , y_test = x_test [:nb_test ], y_test [:nb_test ]
90+
91+ # Train the classifier
92+ tfc = TFClassifier ((0 , 1 ), self ._input_ph , self ._logits , self ._output_ph ,
93+ self ._train , self ._loss , None , self ._sess )
94+ tfc .fit (x_train , y_train , batch_size = batch_size , nb_epochs = 2 )
95+
96+ # Failure attack
97+ cl2m = CarliniL2Method (classifier = tfc , targeted = True , max_iter = 0 , binary_search_steps = 0 ,
98+ learning_rate = 2e-2 , initial_const = 3 , decay = 1e-2 )
99+ params = {'y' : random_targets (y_test , tfc .nb_classes )}
100+ x_test_adv = cl2m .generate (x_test , ** params )
101+ self .assertTrue ((x_test_adv <= 1.0001 ).all ())
102+ self .assertTrue ((x_test_adv >= - 0.0001 ).all ())
103+ np .testing .assert_almost_equal (x_test , x_test_adv , 3 )
104+
38105 def test_tfclassifier (self ):
39106 """
40107 First test with the TFClassifier.
@@ -63,7 +130,7 @@ def test_tfclassifier(self):
63130 self ._sess .run (tf .global_variables_initializer ())
64131
65132 # Get MNIST
66- batch_size , nb_train , nb_test = 100 , 1000 , 10
133+ batch_size , nb_train , nb_test = 100 , 500 , 5
67134 (x_train , y_train ), (x_test , y_test ), _ , _ = load_mnist ()
68135 x_train , y_train = x_train [:nb_train ], y_train [:nb_train ]
69136 x_test , y_test = x_test [:nb_test ], y_test [:nb_test ]
@@ -74,31 +141,38 @@ def test_tfclassifier(self):
74141 tfc .fit (x_train , y_train , batch_size = batch_size , nb_epochs = 2 )
75142
76143 # First attack
77- cl2m = CarliniL2Method (classifier = tfc , targeted = True , max_iter = 100 , binary_search_steps = 10 ,
144+ cl2m = CarliniL2Method (classifier = tfc , targeted = True , max_iter = 10 , binary_search_steps = 10 ,
78145 learning_rate = 2e-2 , initial_const = 3 , decay = 1e-2 )
79146 params = {'y' : random_targets (y_test , tfc .nb_classes )}
80147 x_test_adv = cl2m .generate (x_test , ** params )
81148 self .assertFalse ((x_test == x_test_adv ).all ())
149+ #print(x_test_adv)
150+ self .assertTrue ((x_test_adv <= 1.0001 ).all ())
151+ self .assertTrue ((x_test_adv >= - 0.0001 ).all ())
82152 target = np .argmax (params ['y' ], axis = 1 )
83153 y_pred_adv = np .argmax (tfc .predict (x_test_adv ), axis = 1 )
84154 self .assertTrue ((target == y_pred_adv ).all ())
85155
86156 # Second attack
87- cl2m = CarliniL2Method (classifier = tfc , targeted = False , max_iter = 100 , binary_search_steps = 10 ,
157+ cl2m = CarliniL2Method (classifier = tfc , targeted = False , max_iter = 10 , binary_search_steps = 10 ,
88158 learning_rate = 2e-2 , initial_const = 3 , decay = 1e-2 )
89159 params = {'y' : random_targets (y_test , tfc .nb_classes )}
90160 x_test_adv = cl2m .generate (x_test , ** params )
91161 self .assertFalse ((x_test == x_test_adv ).all ())
162+ self .assertTrue ((x_test_adv <= 1.0001 ).all ())
163+ self .assertTrue ((x_test_adv >= - 0.0001 ).all ())
92164 target = np .argmax (params ['y' ], axis = 1 )
93165 y_pred_adv = np .argmax (tfc .predict (x_test_adv ), axis = 1 )
94166 self .assertTrue ((target != y_pred_adv ).all ())
95167
96168 # Third attack
97- cl2m = CarliniL2Method (classifier = tfc , targeted = False , max_iter = 100 , binary_search_steps = 10 ,
169+ cl2m = CarliniL2Method (classifier = tfc , targeted = False , max_iter = 10 , binary_search_steps = 10 ,
98170 learning_rate = 2e-2 , initial_const = 3 , decay = 1e-2 )
99171 params = {}
100172 x_test_adv = cl2m .generate (x_test , ** params )
101173 self .assertFalse ((x_test == x_test_adv ).all ())
174+ self .assertTrue ((x_test_adv <= 1.0001 ).all ())
175+ self .assertTrue ((x_test_adv >= - 0.0001 ).all ())
102176 y_pred = np .argmax (tfc .predict (x_test ), axis = 1 )
103177 y_pred_adv = np .argmax (tfc .predict (x_test_adv ), axis = 1 )
104178 self .assertTrue ((y_pred != y_pred_adv ).all ())
@@ -113,7 +187,7 @@ def test_krclassifier(self):
113187 k .set_session (session )
114188
115189 # Get MNIST
116- batch_size , nb_train , nb_test = 100 , 1000 , 10
190+ batch_size , nb_train , nb_test = 100 , 500 , 5
117191 (x_train , y_train ), (x_test , y_test ), _ , _ = load_mnist ()
118192 x_train , y_train = x_train [:nb_train ], y_train [:nb_train ]
119193 x_test , y_test = x_test [:nb_test ], y_test [:nb_test ]
@@ -133,35 +207,102 @@ def test_krclassifier(self):
133207 krc .fit (x_train , y_train , batch_size = batch_size , nb_epochs = 2 )
134208
135209 # First attack
136- cl2m = CarliniL2Method (classifier = krc , targeted = True , max_iter = 100 , binary_search_steps = 10 ,
210+ cl2m = CarliniL2Method (classifier = krc , targeted = True , max_iter = 10 , binary_search_steps = 10 ,
137211 learning_rate = 2e-2 , initial_const = 3 , decay = 1e-2 )
138212 params = {'y' : random_targets (y_test , krc .nb_classes )}
139213 x_test_adv = cl2m .generate (x_test , ** params )
140214 self .assertFalse ((x_test == x_test_adv ).all ())
215+ self .assertTrue ((x_test_adv <= 1.0001 ).all ())
216+ self .assertTrue ((x_test_adv >= - 0.0001 ).all ())
141217 target = np .argmax (params ['y' ], axis = 1 )
142218 y_pred_adv = np .argmax (krc .predict (x_test_adv ), axis = 1 )
143219 self .assertTrue ((target == y_pred_adv ).any ())
144220
145221 # Second attack
146- cl2m = CarliniL2Method (classifier = krc , targeted = False , max_iter = 100 , binary_search_steps = 10 ,
222+ cl2m = CarliniL2Method (classifier = krc , targeted = False , max_iter = 10 , binary_search_steps = 10 ,
147223 learning_rate = 2e-2 , initial_const = 3 , decay = 1e-2 )
148224 params = {'y' : random_targets (y_test , krc .nb_classes )}
149225 x_test_adv = cl2m .generate (x_test , ** params )
150226 self .assertFalse ((x_test == x_test_adv ).all ())
227+ self .assertTrue ((x_test_adv <= 1.0001 ).all ())
228+ self .assertTrue ((x_test_adv >= - 0.0001 ).all ())
151229 target = np .argmax (params ['y' ], axis = 1 )
152230 y_pred_adv = np .argmax (krc .predict (x_test_adv ), axis = 1 )
153231 self .assertTrue ((target != y_pred_adv ).all ())
154232
155233 # Third attack
156- cl2m = CarliniL2Method (classifier = krc , targeted = False , max_iter = 100 , binary_search_steps = 10 ,
234+ cl2m = CarliniL2Method (classifier = krc , targeted = False , max_iter = 10 , binary_search_steps = 10 ,
157235 learning_rate = 2e-2 , initial_const = 3 , decay = 1e-2 )
158236 params = {}
159237 x_test_adv = cl2m .generate (x_test , ** params )
160238 self .assertFalse ((x_test == x_test_adv ).all ())
239+ self .assertTrue ((x_test_adv <= 1.0001 ).all ())
240+ self .assertTrue ((x_test_adv >= - 0.0001 ).all ())
161241 y_pred = np .argmax (krc .predict (x_test ), axis = 1 )
162242 y_pred_adv = np .argmax (krc .predict (x_test_adv ), axis = 1 )
163243 self .assertTrue ((y_pred != y_pred_adv ).any ())
164244
245+ def test_ptclassifier (self ):
246+ """
247+ Third test with the PyTorchClassifier.
248+ :return:
249+ """
250+ # Get MNIST
251+ batch_size , nb_train , nb_test = 100 , 1000 , 10
252+ (x_train , y_train ), (x_test , y_test ), _ , _ = load_mnist ()
253+ x_train , y_train = x_train [:nb_train ], np .argmax (y_train [:nb_train ], axis = 1 )
254+ x_test , y_test = x_test [:nb_test ], y_test [:nb_test ]
255+ x_train = np .swapaxes (x_train , 1 , 3 )
256+ x_test = np .swapaxes (x_test , 1 , 3 )
257+
258+ # Create simple CNN
259+ # Define the network
260+ model = Model ()
261+
262+ # Define a loss function and optimizer
263+ loss_fn = nn .CrossEntropyLoss ()
264+ optimizer = optim .Adam (model .parameters (), lr = 0.01 )
265+
266+ # Get classifier
267+ ptc = PyTorchClassifier ((0 , 1 ), model , loss_fn , optimizer , (1 , 28 , 28 ), (10 ,))
268+ ptc .fit (x_train , y_train , batch_size = batch_size , nb_epochs = 1 )
269+
270+ # First attack
271+ cl2m = CarliniL2Method (classifier = ptc , targeted = True , max_iter = 100 , binary_search_steps = 10 ,
272+ learning_rate = 2e-2 , initial_const = 3 , decay = 1e-2 )
273+ params = {'y' : random_targets (y_test , ptc .nb_classes )}
274+ x_test_adv = cl2m .generate (x_test , ** params )
275+ self .assertFalse ((x_test == x_test_adv ).all ())
276+ self .assertTrue ((x_test_adv <= 1.0001 ).all ())
277+ self .assertTrue ((x_test_adv >= - 0.0001 ).all ())
278+ target = np .argmax (params ['y' ], axis = 1 )
279+ y_pred_adv = np .argmax (ptc .predict (x_test_adv ), axis = 1 )
280+ self .assertTrue ((target == y_pred_adv ).any ())
281+
282+ # Second attack
283+ cl2m = CarliniL2Method (classifier = ptc , targeted = False , max_iter = 100 , binary_search_steps = 10 ,
284+ learning_rate = 2e-2 , initial_const = 3 , decay = 1e-2 )
285+ params = {'y' : random_targets (y_test , ptc .nb_classes )}
286+ x_test_adv = cl2m .generate (x_test , ** params )
287+ self .assertFalse ((x_test == x_test_adv ).all ())
288+ self .assertTrue ((x_test_adv <= 1.0001 ).all ())
289+ self .assertTrue ((x_test_adv >= - 0.0001 ).all ())
290+ target = np .argmax (params ['y' ], axis = 1 )
291+ y_pred_adv = np .argmax (ptc .predict (x_test_adv ), axis = 1 )
292+ self .assertTrue ((target != y_pred_adv ).all ())
293+
294+ # Third attack
295+ cl2m = CarliniL2Method (classifier = ptc , targeted = False , max_iter = 100 , binary_search_steps = 10 ,
296+ learning_rate = 2e-2 , initial_const = 3 , decay = 1e-2 )
297+ params = {}
298+ x_test_adv = cl2m .generate (x_test , ** params )
299+ self .assertFalse ((x_test == x_test_adv ).all ())
300+ self .assertTrue ((x_test_adv <= 1.0001 ).all ())
301+ self .assertTrue ((x_test_adv >= - 0.0001 ).all ())
302+ y_pred = np .argmax (ptc .predict (x_test ), axis = 1 )
303+ y_pred_adv = np .argmax (ptc .predict (x_test_adv ), axis = 1 )
304+ self .assertTrue ((y_pred != y_pred_adv ).any ())
305+
165306
166307if __name__ == '__main__' :
167308 unittest .main ()
0 commit comments