Skip to content

Commit 0f0a2cc

Browse files
committed
Fix dependencies
Signed-off-by: Beat Buesser <[email protected]>
1 parent 9b4b84d commit 0f0a2cc

File tree

5 files changed

+88
-125
lines changed

5 files changed

+88
-125
lines changed

art/attacks/poisoning/sleeper_agent_attack.py

Lines changed: 11 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,7 @@
3131

3232
from art.attacks.poisoning.gradient_matching_attack import GradientMatchingAttack
3333
from art.estimators.classification.pytorch import PyTorchClassifier
34-
from art.estimators.classification import TensorFlowV2Classifier
3534
from art.preprocessing.standardisation_mean_std.pytorch import StandardisationMeanStdPyTorch
36-
from art.preprocessing.standardisation_mean_std.tensorflow import StandardisationMeanStdTensorFlow
3735

3836

3937
if TYPE_CHECKING:
@@ -99,15 +97,15 @@ def __init__(
9997
:param class_target: The target label to which the poisoned model needs to misclassify.
10098
:param retrain_batch_size: Batch size required for model retraining.
10199
"""
102-
if isinstance(classifier.preprocessing, (StandardisationMeanStdPyTorch, StandardisationMeanStdTensorFlow)):
100+
if isinstance(classifier.preprocessing, StandardisationMeanStdPyTorch):
103101
clip_values_normalised = (
104102
classifier.clip_values - classifier.preprocessing.mean # type: ignore
105103
) / classifier.preprocessing.std
106104
clip_values_normalised = (clip_values_normalised[0], clip_values_normalised[1])
107105
epsilon_normalised = epsilon * (clip_values_normalised[1] - clip_values_normalised[0]) # type: ignore
108106
patch_normalised = (patch - classifier.preprocessing.mean) / classifier.preprocessing.std
109107
else:
110-
raise ValueError("classifier.preprocessing not an instance of pytorch/tensorflow")
108+
raise ValueError("classifier.preprocessing not an instance of pytorch")
111109

112110
super().__init__(
113111
classifier,
@@ -157,9 +155,7 @@ def poison( # type: ignore
157155
"""
158156
# Apply Normalisation
159157
x_train = np.copy(x_train)
160-
if isinstance(
161-
self.substitute_classifier.preprocessing, (StandardisationMeanStdPyTorch, StandardisationMeanStdTensorFlow)
162-
):
158+
if isinstance(self.substitute_classifier.preprocessing, StandardisationMeanStdPyTorch):
163159
x_trigger = (
164160
x_trigger - self.substitute_classifier.preprocessing.mean
165161
) / self.substitute_classifier.preprocessing.std
@@ -172,12 +168,8 @@ def poison( # type: ignore
172168
poisoner = self._poison__pytorch
173169
finish_poisoning = self._finish_poison_pytorch
174170
initializer = self._initialize_poison_pytorch
175-
elif isinstance(self.substitute_classifier, TensorFlowV2Classifier):
176-
poisoner = self._poison__tensorflow
177-
finish_poisoning = self._finish_poison_tensorflow
178-
initializer = self._initialize_poison_tensorflow
179171
else:
180-
raise NotImplementedError("SleeperAgentAttack is currently implemented only for PyTorch and TensorFlowV2.")
172+
raise NotImplementedError("SleeperAgentAttack is currently implemented only for PyTorch.")
181173

182174
# Choose samples to poison.
183175
x_trigger = self._apply_trigger_patch(x_trigger)
@@ -237,9 +229,7 @@ def poison( # type: ignore
237229
self.indices_poison = best_indices_poison
238230

239231
# Apply De-Normalization
240-
if isinstance(
241-
self.substitute_classifier.preprocessing, (StandardisationMeanStdPyTorch, StandardisationMeanStdTensorFlow)
242-
):
232+
if isinstance(self.substitute_classifier.preprocessing, StandardisationMeanStdPyTorch):
243233
x_train = (
244234
x_train * self.substitute_classifier.preprocessing.std + self.substitute_classifier.preprocessing.mean
245235
)
@@ -251,10 +241,8 @@ def poison( # type: ignore
251241
logger.info("Best B-score: %s", best_B)
252242
if isinstance(self.substitute_classifier, PyTorchClassifier):
253243
x_train[self.indices_target[best_indices_poison]] = best_x_poisoned
254-
elif isinstance(self.substitute_classifier, TensorFlowV2Classifier):
255-
x_train[self.indices_target[best_indices_poison]] = best_x_poisoned
256244
else:
257-
raise NotImplementedError("SleeperAgentAttack is currently implemented only for PyTorch and TensorFlowV2.")
245+
raise NotImplementedError("SleeperAgentAttack is currently implemented only for PyTorch.")
258246
return x_train, y_train
259247

260248
def _select_target_train_samples(self, x_train: np.ndarray, y_train: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
@@ -294,9 +282,7 @@ def _model_retraining(
294282
:param x_test: clean test data.
295283
:param y_test: labels for test data.
296284
"""
297-
if isinstance(
298-
self.substitute_classifier.preprocessing, (StandardisationMeanStdPyTorch, StandardisationMeanStdTensorFlow)
299-
):
285+
if isinstance(self.substitute_classifier.preprocessing, StandardisationMeanStdPyTorch):
300286
x_train_un = np.copy(x_train)
301287
x_train_un[self.indices_target[self.indices_poison]] = poisoned_samples
302288
x_train_un = x_train_un * self.substitute_classifier.preprocessing.std
@@ -315,22 +301,8 @@ def _model_retraining(
315301
self.substitute_classifier = model_pt
316302
self.substitute_classifier.model.training = check_train
317303

318-
elif isinstance(self.substitute_classifier, TensorFlowV2Classifier):
319-
check_train = self.substitute_classifier.model.trainable
320-
model_tf = self._create_model(
321-
x_train_un,
322-
y_train,
323-
x_test,
324-
y_test,
325-
batch_size=self.retrain_batch_size,
326-
epochs=self.model_retraining_epoch,
327-
)
328-
329-
self.substitute_classifier = model_tf
330-
self.substitute_classifier.model.trainable = check_train
331-
332304
else:
333-
raise NotImplementedError("SleeperAgentAttack is currently implemented only for PyTorch and TensorFlowV2.")
305+
raise NotImplementedError("SleeperAgentAttack is currently implemented only for PyTorch.")
334306

335307
def _create_model(
336308
self,
@@ -340,7 +312,7 @@ def _create_model(
340312
y_test: np.ndarray,
341313
batch_size: int = 128,
342314
epochs: int = 80,
343-
) -> "TensorFlowV2Classifier" | "PyTorchClassifier":
315+
) -> "PyTorchClassifier":
344316
"""
345317
Creates a new model.
346318
@@ -365,17 +337,7 @@ def _create_model(
365337
logger.info("Accuracy of retrained model : %s", accuracy * 100.0)
366338
return model_pt
367339

368-
if isinstance(self.substitute_classifier, TensorFlowV2Classifier):
369-
370-
self.substitute_classifier.model.trainable = True
371-
model_tf = self.substitute_classifier.clone_for_refitting()
372-
model_tf.fit(x_train, y_train, batch_size=batch_size, nb_epochs=epochs, verbose=False)
373-
predictions = model_tf.predict(x_test)
374-
accuracy = np.sum(np.argmax(predictions, axis=1) == np.argmax(y_test, axis=1)) / len(y_test)
375-
logger.info("Accuracy of retrained model : %s", accuracy * 100.0)
376-
return model_tf
377-
378-
raise ValueError("SleeperAgentAttack is currently implemented only for PyTorch and TensorFlowV2.")
340+
raise ValueError("SleeperAgentAttack is currently implemented only for PyTorch.")
379341

380342
# This function is responsible for returning indices of poison images with maximum gradient norm
381343
def _select_poison_indices(
@@ -408,28 +370,8 @@ def _select_poison_indices(
408370
for grad in gradients:
409371
grad_norm += grad.detach().pow(2).sum()
410372
grad_norms.append(grad_norm.sqrt())
411-
elif isinstance(self.substitute_classifier, TensorFlowV2Classifier):
412-
import tensorflow as tf
413-
414-
model_trainable = classifier.model.trainable
415-
classifier.model.trainable = False
416-
grad_norms = []
417-
for i in range(len(x_samples) - 1):
418-
image = tf.constant(x_samples[i : i + 1])
419-
label = tf.constant(y_samples[i : i + 1])
420-
with tf.GradientTape() as t: # pylint: disable=invalid-name
421-
t.watch(classifier.model.weights)
422-
output = classifier.model(image, training=False)
423-
loss_tf = classifier.loss_object(label, output) # type: ignore
424-
gradients = list(t.gradient(loss_tf, classifier.model.weights))
425-
gradients = [w for w in gradients if w is not None]
426-
grad_norm = tf.constant(0, dtype=tf.float32)
427-
for grad in gradients:
428-
grad_norm += tf.reduce_sum(tf.math.square(grad))
429-
grad_norms.append(tf.math.sqrt(grad_norm))
430-
classifier.model.trainable = model_trainable
431373
else:
432-
raise NotImplementedError("SleeperAgentAttack is currently implemented only for PyTorch and TensorFlowV2.")
374+
raise NotImplementedError("SleeperAgentAttack is currently implemented only for PyTorch.")
433375
indices = sorted(range(len(grad_norms)), key=lambda k: grad_norms[k]) # type: ignore
434376
indices = indices[-num_poison:]
435377
return np.array(indices) # this will get only indices for target class

art/estimators/poison_mitigation/neural_cleanse/keras.py

Lines changed: 72 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,10 @@ def __init__(
122122
:param cost_multiplier: How much to change the cost in the Neural Cleanse optimization
123123
:param batch_size: The batch size for optimizations in the Neural Cleanse optimization
124124
"""
125+
import tensorflow as tf
126+
from tensorflow.keras.layers import Lambda
125127
import keras.backend as K
128+
from keras.optimizers import Adam
126129
from keras.losses import categorical_crossentropy
127130
from keras.metrics import categorical_accuracy
128131

@@ -153,50 +156,66 @@ def __init__(
153156
self.epsilon = K.epsilon()
154157

155158
# Normalize mask between [0, 1]
156-
self.mask_tensor_raw = K.variable(mask)
157-
# self.mask_tensor = K.expand_dims(K.tanh(self.mask_tensor_raw) / (2 - self.epsilon) + 0.5, axis=0)
158-
self.mask_tensor = K.tanh(self.mask_tensor_raw) / (2 - self.epsilon) + 0.5
159+
self.mask_tensor_raw = tf.Variable(mask, dtype=tf.float32)
160+
# self.mask_tensor = tf.math.tanh(self.mask_tensor_raw) / (2.0 - self.epsilon) + 0.5
159161

160162
# Normalize pattern between [0, 1]
161-
self.pattern_tensor_raw = K.variable(pattern)
162-
self.pattern_tensor = K.expand_dims(K.tanh(self.pattern_tensor_raw) / (2 - self.epsilon) + 0.5, axis=0)
163+
self.pattern_tensor_raw = tf.Variable(pattern, dtype=tf.float32)
164+
# self.pattern_tensor = tf.expand_dims(tf.math.tanh(self.pattern_tensor_raw) / (2 - self.epsilon) + 0.5, axis=0)
163165

164-
reverse_mask_tensor = K.ones_like(self.mask_tensor) - self.mask_tensor
165-
input_tensor = K.placeholder(model.input_shape)
166-
x_adv_tensor = reverse_mask_tensor * input_tensor + self.mask_tensor * self.pattern_tensor
166+
# @tf.function
167+
def train_step(x_batch, y_batch):
    """
    Run one optimisation step on the raw mask and pattern variables.

    The step rebuilds the mask and pattern from their raw variables, blends the pattern into
    `x_batch` through the mask, evaluates the wrapped model on the result and applies one
    optimiser update to `self.mask_tensor_raw` and `self.pattern_tensor_raw`.

    As a side effect the current squashed mask and pattern are stored on `self.mask_tensor`
    and `self.pattern_tensor`, so they can be read back after the step.

    :param x_batch: Batch of clean inputs; must broadcast against the mask and pattern.
    :param y_batch: Target labels for the batch. The class axis is axis 1 (presumably one-hot
                    encoded - confirm against the caller).
    :return: A tuple `(loss_ce, loss_reg, loss_combined, loss_acc)`: the cross-entropy as returned
             by `categorical_crossentropy` (not averaged here), the mask regularisation term,
             the combined loss that was differentiated, and the mean batch accuracy with respect
             to `y_batch`.
    :raises ValueError: If `self.norm` is neither 1 nor 2.
    """
    # Single source of truth for the variables being optimised, so that the gradient
    # computation and the optimiser update can never get out of sync.
    trainable_variables = [self.mask_tensor_raw, self.pattern_tensor_raw]

    with tf.GradientTape() as tape:
        # Squash the unconstrained raw variables with tanh and shift them towards [0, 1]
        self.mask_tensor = tf.tanh(self.mask_tensor_raw) / (2 - self.epsilon) + 0.5
        self.pattern_tensor = tf.tanh(self.pattern_tensor_raw) / (2 - self.epsilon) + 0.5

        # Construct adversarial example: keep the input where the mask is 0, paste the
        # pattern where the mask is 1
        reverse_mask_tensor = 1.0 - self.mask_tensor
        x_adv = reverse_mask_tensor * x_batch + self.mask_tensor * self.pattern_tensor

        # Forward pass in inference mode
        y_pred = self.model(x_adv, training=False)

        # Classification loss
        loss_ce = tf.keras.losses.categorical_crossentropy(y_batch, y_pred, from_logits=self.use_logits)

        # Accuracy: fraction of samples whose predicted class equals the requested class
        correct = tf.equal(tf.argmax(y_pred, axis=1), tf.argmax(y_batch, axis=1))
        loss_acc = tf.reduce_mean(tf.cast(correct, tf.float32))

        # Regularization loss: norm of the mask, scaled by the size of the mask's last axis
        last_axis_size = tf.cast(tf.shape(self.mask_tensor)[-1], tf.float32)
        if self.norm == 1:
            loss_reg = tf.reduce_sum(tf.abs(self.mask_tensor)) / last_axis_size
        elif self.norm == 2:
            loss_reg = tf.sqrt(tf.reduce_sum(tf.square(self.mask_tensor)) / last_axis_size)
        else:
            raise ValueError(f"Unsupported norm {self.norm}")

        # Total loss: mean cross-entropy plus the cost-weighted regularisation term
        loss_combined = tf.reduce_mean(loss_ce) + self.cost * loss_reg

    # Compute and apply gradients for the raw mask and pattern only
    grads = tape.gradient(loss_combined, trainable_variables)
    self.opt.apply_gradients(zip(grads, trainable_variables))

    return loss_ce, loss_reg, loss_combined, loss_acc
211+
212+
self.train = train_step
213+
214+
# Initialize cost (as a TensorFlow variable so it can be updated during training)
215+
self.cost = self.init_cost
216+
self.cost_tensor = tf.Variable(self.cost, trainable=False, dtype=tf.float32)
217+
218+
self.opt = Adam(learning_rate=self.learning_rate, beta_1=0.5, beta_2=0.9)
200219

201220
@property
202221
def input_shape(self) -> tuple[int, ...]:
@@ -212,13 +231,14 @@ def reset(self):
212231
Reset the state of the defense
213232
:return:
214233
"""
215-
import keras.backend as K
234+
import tensorflow as tf
216235

217236
self.cost = self.init_cost
218-
K.set_value(self.cost_tensor, self.init_cost)
219-
K.set_value(self.opt.iterations, 0)
220-
for weight in self.opt.weights:
221-
K.set_value(weight, np.zeros(K.int_shape(weight)))
237+
self.cost_tensor.assign(self.init_cost)
238+
self.opt.iterations.assign(0)
239+
if self.opt._variables:
240+
for var in self.opt._variables:
241+
var.assign(tf.zeros_like(var))
222242

223243
def generate_backdoor(
224244
self, x_val: np.ndarray, y_val: np.ndarray, y_target: np.ndarray
@@ -227,8 +247,9 @@ def generate_backdoor(
227247
Generates a possible backdoor for the model. Returns the pattern and the mask
228248
:return: A tuple of the pattern and mask for the model.
229249
"""
250+
import tensorflow as tf
230251
import keras.backend as K
231-
from keras.preprocessing.image import ImageDataGenerator
252+
from tensorflow.keras.preprocessing.image import ImageDataGenerator
232253

233254
self.reset()
234255
datagen = ImageDataGenerator()
@@ -249,20 +270,20 @@ def generate_backdoor(
249270
loss_acc_list = []
250271

251272
for _ in range(mini_batch_size):
252-
x_batch, _ = gen.next()
273+
x_batch, _ = next(gen)
253274
y_batch = [y_target] * x_batch.shape[0]
254-
_, batch_loss_reg, _, batch_loss_acc = self.train([x_batch, y_batch])
275+
_, batch_loss_reg, _, batch_loss_acc = self.train(x_batch, y_batch)
255276

256-
loss_reg_list.extend(list(batch_loss_reg.flatten()))
257-
loss_acc_list.extend(list(batch_loss_acc.flatten()))
277+
loss_reg_list.extend(list(tf.reshape(batch_loss_reg, [-1]).numpy()))
278+
loss_acc_list.extend(list(tf.reshape(batch_loss_acc, [-1]).numpy()))
258279

259280
avg_loss_reg = np.mean(loss_reg_list)
260281
avg_loss_acc = np.mean(loss_acc_list)
261282

262283
# save best mask/pattern so far
263284
if avg_loss_acc >= self.attack_success_threshold and avg_loss_reg < reg_best:
264-
mask_best = K.eval(self.mask_tensor)
265-
pattern_best = K.eval(self.pattern_tensor)
285+
mask_best = self.mask_tensor.numpy()
286+
pattern_best = self.pattern_tensor.numpy()
266287
reg_best = avg_loss_reg
267288

268289
# check early stop
@@ -283,7 +304,7 @@ def generate_backdoor(
283304
cost_set_counter += 1
284305
if cost_set_counter >= self.patience:
285306
self.cost = self.init_cost
286-
K.set_value(self.cost_tensor, self.cost)
307+
self.cost_tensor.assign(self.cost)
287308
cost_up_counter = 0
288309
cost_down_counter = 0
289310
cost_up_flag = False
@@ -301,17 +322,17 @@ def generate_backdoor(
301322
if cost_up_counter >= self.patience:
302323
cost_up_counter = 0
303324
self.cost *= self.cost_multiplier_up
304-
K.set_value(self.cost_tensor, self.cost)
325+
self.cost_tensor.assign(self.cost)
305326
cost_up_flag = True
306327
elif cost_down_counter >= self.patience:
307328
cost_down_counter = 0
308329
self.cost /= self.cost_multiplier_down
309-
K.set_value(self.cost_tensor, self.cost)
330+
self.cost_tensor.assign(self.cost)
310331
cost_down_flag = True
311332

312333
if mask_best is None:
313-
mask_best = K.eval(self.mask_tensor)
314-
pattern_best = K.eval(self.pattern_tensor)
334+
mask_best = self.mask_tensor.numpy()
335+
pattern_best = self.pattern_tensor.numpy()
315336

316337
if pattern_best is None:
317338
raise ValueError("Unexpected `None` detected.")

tests/attacks/poison/test_sleeper_agent_attack.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
logger = logging.getLogger(__name__)
2929

3030

31-
@pytest.mark.only_with_platform("pytorch", "tensorflow2")
31+
@pytest.mark.only_with_platform("pytorch")
3232
def test_poison(art_warning, get_default_mnist_subset, image_dl_estimator, framework):
3333
try:
3434
(x_train, y_train), (x_test, y_test) = get_default_mnist_subset
@@ -85,7 +85,7 @@ def test_poison(art_warning, get_default_mnist_subset, image_dl_estimator, frame
8585
art_warning(e)
8686

8787

88-
@pytest.mark.only_with_platform("pytorch", "tensorflow2")
88+
@pytest.mark.only_with_platform("pytorch")
8989
def test_check_params(art_warning, get_default_mnist_subset, image_dl_estimator):
9090
try:
9191
classifier, _ = image_dl_estimator(functional=True)

0 commit comments

Comments
 (0)