# Imports
!pip install adversarial-robustness-toolbox
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import Subset
from torchvision import datasets, transforms, models
from art.estimators.classification import PyTorchClassifier
from art.utils import to_categorical
from art.attacks.poisoning import PoisoningAttackBackdoor

# Trigger generator: a small CNN that learns to produce input-specific triggers
class TriggerGenerator(nn.Module):
    def __init__(self, input_channels=3):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(input_channels, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, input_channels, kernel_size=3, padding=1),
            nn.Tanh()  # bound trigger values to [-1, 1]
        )

    def forward(self, x):
        return self.net(x)
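
# Quick shape sanity check (an illustrative sketch: the random batch below is an
# assumption standing in for real 32x32 RGB images and is not part of the original notebook)
gen = TriggerGenerator(input_channels=3)
dummy = torch.rand(4, 3, 32, 32)  # four random images with pixel values in [0, 1]
trigger = gen(dummy)
assert trigger.shape == dummy.shape  # the trigger has exactly the same shape as the input
assert trigger.min() >= -1.0 and trigger.max() <= 1.0  # the final Tanh bounds the values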

# Custom poisoning attack: DynamicBackdoorGAN. This class defines how to poison data
# using the GAN trigger generator.
class DynamicBackdoorGAN(PoisoningAttackBackdoor):
    def __init__(self, generator, target_label, backdoor_rate, classifier, epsilon=0.5):
        super().__init__(perturbation=lambda x: x)  # identity placeholder; triggers are applied in poison()
        self.classifier = classifier
        self.generator = generator.to(classifier.device)
        self.target_label = target_label
        self.backdoor_rate = backdoor_rate
        self.epsilon = epsilon
    # Add the trigger to a given image batch
    def apply_trigger(self, images):
        self.generator.eval()
        with torch.no_grad():
            # Resize images to ensure a uniform 32x32 dimension
            images = nn.functional.interpolate(images, size=(32, 32), mode='bilinear')
            # Generate dynamic, input-specific triggers using the trained TriggerGenerator
            triggers = self.generator(images.to(self.classifier.device))
            # Blend in the trigger and clamp pixel values to the valid [0, 1] range
            poisoned = (images.to(self.classifier.device) + self.epsilon * triggers).clamp(0, 1)
        return poisoned
    # Poison the training data by injecting dynamic triggers and changing labels
    def poison(self, x, y):
        # Convert raw image data (x) to float tensors, and convert the one-hot labels (y)
        # that ART passes in to class indices
        x_tensor = torch.tensor(x).float()
        y_tensor = torch.tensor(np.argmax(y, axis=1))
        # Compute how many samples to poison (poison ratio = backdoor_rate)
        batch_size = x_tensor.shape[0]
        n_poison = int(self.backdoor_rate * batch_size)
        # Apply the learned trigger to the first n_poison samples
        poisoned = self.apply_trigger(x_tensor[:n_poison])
        # The remaining samples stay clean (inputs are assumed to already be 32x32,
        # so their shape matches the resized poisoned batch)
        clean = x_tensor[n_poison:].to(self.classifier.device)
        # Combine poisoned and clean samples into a single batch
        poisoned_images = torch.cat([poisoned, clean], dim=0).cpu().numpy()
        # Relabel the poisoned samples with the attacker's target class
        new_labels = y_tensor.clone()
        new_labels[:n_poison] = self.target_label
        # Convert all labels back to one-hot encoding (required by ART classifiers)
        new_labels = to_categorical(new_labels.numpy(), nb_classes=self.classifier.nb_classes)
        return poisoned_images.astype(np.float32), new_labels.astype(np.float32)
    # Evaluate the attack's success on test data
    def evaluate(self, x_clean, y_clean):
        x_tensor = torch.tensor(x_clean).float()
        # Apply the trigger to every test image to create a fully poisoned test set
        poisoned_test = self.apply_trigger(x_tensor).cpu().numpy().astype(np.float32)

        preds = self.classifier.predict(poisoned_test)
        true_target = np.full((len(preds),), self.target_label)
        pred_labels = np.argmax(preds, axis=1)

        # Attack success rate (ASR): percentage of triggered inputs classified as the target label
        success = np.sum(pred_labels == true_target)
        asr = 100.0 * success / len(pred_labels)
        return asr
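
# End-to-end usage sketch. Assumptions not taken from this section: a ResNet-18 classifier,
# a small CIFAR-10 subset, a single training epoch, and an untrained TriggerGenerator
# (the generator's training loop is not shown here, so this demonstrates the data flow
# rather than a strong attack).
model = models.resnet18(num_classes=10)
classifier = PyTorchClassifier(
    model=model,
    loss=nn.CrossEntropyLoss(),
    optimizer=torch.optim.Adam(model.parameters(), lr=1e-3),
    input_shape=(3, 32, 32),
    nb_classes=10,
)

data = datasets.CIFAR10(root="./data", train=True, download=True, transform=transforms.ToTensor())
subset = Subset(data, range(512))  # small subset to keep the sketch fast
x_train = torch.stack([img for img, _ in subset]).numpy()
y_train = to_categorical(np.array([lbl for _, lbl in subset]), nb_classes=10)

attack = DynamicBackdoorGAN(
    generator=TriggerGenerator(input_channels=3),
    target_label=0,       # attacker's desired misclassification
    backdoor_rate=0.1,    # poison 10% of the training batch
    classifier=classifier,
    epsilon=0.5,
)
x_poisoned, y_poisoned = attack.poison(x_train, y_train)
classifier.fit(x_poisoned, y_poisoned, batch_size=64, nb_epochs=1)
print(f"Attack success rate: {attack.evaluate(x_train, y_train):.2f}%")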