@@ -122,7 +122,10 @@ def __init__(
         :param cost_multiplier: How much to change the cost in the Neural Cleanse optimization
         :param batch_size: The batch size for optimizations in the Neural Cleanse optimization
         """
+        import tensorflow as tf
+        from tensorflow.keras.layers import Lambda
         import keras.backend as K
+        from keras.optimizers import Adam
         from keras.losses import categorical_crossentropy
         from keras.metrics import categorical_accuracy
@@ -153,50 +156,66 @@ def __init__(
         self.epsilon = K.epsilon()
 
         # Normalize mask between [0, 1]
-        self.mask_tensor_raw = K.variable(mask)
-        # self.mask_tensor = K.expand_dims(K.tanh(self.mask_tensor_raw) / (2 - self.epsilon) + 0.5, axis=0)
-        self.mask_tensor = K.tanh(self.mask_tensor_raw) / (2 - self.epsilon) + 0.5
+        self.mask_tensor_raw = tf.Variable(mask, dtype=tf.float32)
+        # self.mask_tensor = tf.math.tanh(self.mask_tensor_raw) / (2.0 - self.epsilon) + 0.5
 
         # Normalize pattern between [0, 1]
-        self.pattern_tensor_raw = K.variable(pattern)
-        self.pattern_tensor = K.expand_dims(K.tanh(self.pattern_tensor_raw) / (2 - self.epsilon) + 0.5, axis=0)
+        self.pattern_tensor_raw = tf.Variable(pattern, dtype=tf.float32)
+        # self.pattern_tensor = tf.expand_dims(tf.math.tanh(self.pattern_tensor_raw) / (2 - self.epsilon) + 0.5, axis=0)
 
-        reverse_mask_tensor = K.ones_like(self.mask_tensor) - self.mask_tensor
-        input_tensor = K.placeholder(model.input_shape)
-        x_adv_tensor = reverse_mask_tensor * input_tensor + self.mask_tensor * self.pattern_tensor
+        # @tf.function
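+        # Note: the step is kept eager (decorator left commented out); it reassigns
+        # self.mask_tensor / self.pattern_tensor as Python side effects, which
+        # tf.function would only execute during tracing.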
+        def train_step(x_batch, y_batch):
+            with tf.GradientTape() as tape:
+                # Normalize mask and pattern
+                self.mask_tensor = tf.tanh(self.mask_tensor_raw) / (2 - self.epsilon) + 0.5
+                self.pattern_tensor = tf.tanh(self.pattern_tensor_raw) / (2 - self.epsilon) + 0.5
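+                # tanh keeps the raw variables unconstrained for the optimizer while
+                # squashing the effective mask/pattern into roughly [0, 1].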
 
-        output_tensor = self.model(x_adv_tensor)
-        y_true_tensor = K.placeholder(model.outputs[0].shape.as_list())
+                # Construct adversarial example
+                reverse_mask_tensor = 1.0 - self.mask_tensor
+                x_adv = reverse_mask_tensor * x_batch + self.mask_tensor * self.pattern_tensor
 
-        self.loss_acc = categorical_accuracy(output_tensor, y_true_tensor)
-        self.loss_ce = categorical_crossentropy(output_tensor, y_true_tensor)
+                # Forward pass
+                y_pred = self.model(x_adv, training=False)
 
-        if self.norm == 1:
-            # TODO: change 3 to dynamically set img_color
-            self.loss_reg = K.sum(K.abs(self.mask_tensor)) / 3
-        elif self.norm == 2:
-            self.loss_reg = K.sqrt(K.sum(K.square(self.mask_tensor)) / 3)
+                # Classification loss
+                loss_ce = tf.keras.losses.categorical_crossentropy(y_batch, y_pred, from_logits=self.use_logits)
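+                # from_logits must match the model head: True for raw logits, False when
+                # the model already applies softmax.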
 
-        self.cost = self.init_cost
-        self.cost_tensor = K.variable(self.cost)
-        self.loss_combined = self.loss_ce + self.loss_reg * self.cost_tensor
+                # Accuracy
+                correct = tf.equal(tf.argmax(y_pred, axis=1), tf.argmax(y_batch, axis=1))
+                loss_acc = tf.reduce_mean(tf.cast(correct, tf.float32))
 
-        try:
-            from keras.optimizers import Adam
+                # Regularization loss
+                if self.norm == 1:
+                    loss_reg = tf.reduce_sum(tf.abs(self.mask_tensor)) / tf.cast(
+                        tf.shape(self.mask_tensor)[-1], tf.float32
+                    )
+                elif self.norm == 2:
+                    loss_reg = tf.sqrt(
+                        tf.reduce_sum(tf.square(self.mask_tensor)) / tf.cast(tf.shape(self.mask_tensor)[-1], tf.float32)
+                    )
+                else:
+                    raise ValueError(f"Unsupported norm {self.norm}")
 
-            self.opt = Adam(lr=self.learning_rate, beta_1=0.5, beta_2=0.9)
-        except ImportError:
-            from keras.optimizers import adam_v2
+                # Total loss
+                loss_combined = tf.reduce_mean(loss_ce) + self.cost * loss_reg
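+                # self.cost is a plain Python float read on every call, so the cost
+                # schedule in generate_backdoor takes effect on the next step.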
 
-            self.opt = adam_v2.Adam(lr=self.learning_rate, beta_1=0.5, beta_2=0.9)
-        self.updates = self.opt.get_updates(
-            params=[self.pattern_tensor_raw, self.mask_tensor_raw], loss=self.loss_combined
-        )
-        self.train = K.function(
-            [input_tensor, y_true_tensor],
-            [self.loss_ce, self.loss_reg, self.loss_combined, self.loss_acc],
-            updates=self.updates,
-        )
+            # Compute gradients
+            grads = tape.gradient(loss_combined, [self.mask_tensor_raw, self.pattern_tensor_raw])
+
+            # Apply gradients
+            self.opt.apply_gradients(zip(grads, [self.mask_tensor_raw, self.pattern_tensor_raw]))
+
+            return loss_ce, loss_reg, loss_combined, loss_acc
+
+        self.train = train_step
+
+        # Initialize cost (as a TensorFlow variable so it can be updated during training)
+        self.cost = self.init_cost
+        self.cost_tensor = tf.Variable(self.cost, trainable=False, dtype=tf.float32)
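+        # cost_tensor mirrors self.cost so reset() and generate_backdoor can keep
+        # updating it via .assign(); train_step itself reads the Python float.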
+
+        self.opt = Adam(learning_rate=self.learning_rate, beta_1=0.5, beta_2=0.9)
 
     @property
     def input_shape(self) -> tuple[int, ...]:
@@ -212,13 +231,14 @@ def reset(self):
         Reset the state of the defense
         :return:
         """
-        import keras.backend as K
+        import tensorflow as tf
 
         self.cost = self.init_cost
-        K.set_value(self.cost_tensor, self.init_cost)
-        K.set_value(self.opt.iterations, 0)
-        for weight in self.opt.weights:
-            K.set_value(weight, np.zeros(K.int_shape(weight)))
+        self.cost_tensor.assign(self.init_cost)
+        self.opt.iterations.assign(0)
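+        # Note: _variables is private optimizer state (its contents vary across
+        # Keras/TF versions); zeroing it is intended to reset Adam's moment
+        # accumulators between optimization runs.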
+        if self.opt._variables:
+            for var in self.opt._variables:
+                var.assign(tf.zeros_like(var))
 
     def generate_backdoor(
         self, x_val: np.ndarray, y_val: np.ndarray, y_target: np.ndarray
@@ -227,8 +247,9 @@ def generate_backdoor(
         Generates a possible backdoor for the model. Returns the pattern and the mask
         :return: A tuple of the pattern and mask for the model.
         """
+        import tensorflow as tf
         import keras.backend as K
-        from keras.preprocessing.image import ImageDataGenerator
+        from tensorflow.keras.preprocessing.image import ImageDataGenerator
 
         self.reset()
         datagen = ImageDataGenerator()
@@ -249,20 +270,20 @@ def generate_backdoor(
             loss_acc_list = []
 
             for _ in range(mini_batch_size):
-                x_batch, _ = gen.next()
+                x_batch, _ = next(gen)
                 y_batch = [y_target] * x_batch.shape[0]
-                _, batch_loss_reg, _, batch_loss_acc = self.train([x_batch, y_batch])
+                _, batch_loss_reg, _, batch_loss_acc = self.train(x_batch, y_batch)
 
-                loss_reg_list.extend(list(batch_loss_reg.flatten()))
-                loss_acc_list.extend(list(batch_loss_acc.flatten()))
+                loss_reg_list.extend(list(tf.reshape(batch_loss_reg, [-1]).numpy()))
+                loss_acc_list.extend(list(tf.reshape(batch_loss_acc, [-1]).numpy()))
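+                # batch_loss_reg and batch_loss_acc are scalar eager tensors here, so
+                # the reshape yields one-element arrays and preserves the list-based API.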
 
             avg_loss_reg = np.mean(loss_reg_list)
             avg_loss_acc = np.mean(loss_acc_list)
 
             # save best mask/pattern so far
             if avg_loss_acc >= self.attack_success_threshold and avg_loss_reg < reg_best:
-                mask_best = K.eval(self.mask_tensor)
-                pattern_best = K.eval(self.pattern_tensor)
+                mask_best = self.mask_tensor.numpy()
+                pattern_best = self.pattern_tensor.numpy()
                 reg_best = avg_loss_reg
 
             # check early stop
@@ -283,7 +304,7 @@ def generate_backdoor(
                 cost_set_counter += 1
                 if cost_set_counter >= self.patience:
                     self.cost = self.init_cost
-                    K.set_value(self.cost_tensor, self.cost)
+                    self.cost_tensor.assign(self.cost)
                     cost_up_counter = 0
                     cost_down_counter = 0
                     cost_up_flag = False
@@ -301,17 +322,17 @@ def generate_backdoor(
             if cost_up_counter >= self.patience:
                 cost_up_counter = 0
                 self.cost *= self.cost_multiplier_up
-                K.set_value(self.cost_tensor, self.cost)
+                self.cost_tensor.assign(self.cost)
                 cost_up_flag = True
             elif cost_down_counter >= self.patience:
                 cost_down_counter = 0
                 self.cost /= self.cost_multiplier_down
-                K.set_value(self.cost_tensor, self.cost)
+                self.cost_tensor.assign(self.cost)
                 cost_down_flag = True
 
         if mask_best is None:
-            mask_best = K.eval(self.mask_tensor)
-            pattern_best = K.eval(self.pattern_tensor)
+            mask_best = self.mask_tensor.numpy()
+            pattern_best = self.pattern_tensor.numpy()
 
         if pattern_best is None:
             raise ValueError("Unexpected `None` detected.")