|
1 | | -from Adversarial_Observation.utils import load_pretrained_model, load_data, fgsm_attack, pgd_attack # Assuming utils.py contains this function |
2 | | -from Adversarial_Observation import AdversarialTester |
3 | | -from Adversarial_Observation import ParticleSwarm |
4 | 1 | import torch |
| 2 | +from tqdm import tqdm |
| 3 | +import time |
| 4 | +from torch.utils.data import DataLoader, TensorDataset |
| 5 | +from Adversarial_Observation.utils import load_MNIST_model, load_data |
| 6 | +from Adversarial_Observation import AdversarialTester, ParticleSwarm |
5 | 7 |
|
6 | | -def adversarial_attack_whitebox(model, dataloader): |
| 8 | + |
| 9 | +def adversarial_attack_whitebox(model: torch.nn.Module, dataloader: DataLoader) -> None: |
| 10 | + """ |
| 11 | + Performs a white-box adversarial attack on the model using AdversarialTester. |
| 12 | + |
| 13 | + Args: |
| 14 | + model (torch.nn.Module): The trained model to attack. |
| 15 | + dataloader (DataLoader): The data loader containing the dataset. |
| 16 | + """ |
7 | 17 | # Initialize the AdversarialTester with the model |
8 | 18 | attacker = AdversarialTester(model) |
9 | 19 |
|
10 | 20 | # Perform the attack on the dataset |
11 | 21 | for images, _ in dataloader: |
12 | 22 | attacker.test_attack(images) |
13 | 23 |
|
14 | | -# Example function call |
15 | | -def adversarial_attack_blackbox(model, dataloader): |
16 | | - single_image_input = dataloader.dataset[0][0] # Get the first image from the dataset |
17 | | - single_image_target = torch.argmax(model(single_image_input.unsqueeze(0))) # Get the target label for the first image |
18 | 24 |
|
19 | | - single_misclassification_input = dataloader.dataset[1][0] # Get the second image from the dataset |
20 | | - single_misclassification_target = torch.argmax(model(single_misclassification_input.unsqueeze(0))) # Get the target label for the second image |
| 25 | +def adversarial_attack_blackbox(model: torch.nn.Module, dataloader: DataLoader) -> DataLoader: |
| 26 | + """ |
| 27 | + Performs a black-box adversarial attack on the model using Particle Swarm optimization. |
| 28 | + |
| 29 | + Args: |
| 30 | + model (torch.nn.Module): The trained model to attack. |
| 31 | + dataloader (DataLoader): The data loader containing the dataset. |
| 32 | + |
| 33 | + Returns: |
| 34 | + DataLoader: A dataloader containing adversarially perturbed images. |
| 35 | + """ |
| 36 | + # Get the first two images from the dataset to simulate misclassification |
| 37 | + single_image_input = dataloader.dataset[0][0] |
| 38 | + single_image_target = torch.argmax(model(single_image_input.unsqueeze(0))) |
| 39 | + |
| 40 | + single_misclassification_input = dataloader.dataset[1][0] |
| 41 | + single_misclassification_target = torch.argmax(model(single_misclassification_input.unsqueeze(0))) |
| 42 | + |
| 43 | + # Ensure the targets are different to simulate misclassification |
| 44 | + assert single_image_target != single_misclassification_target, \ |
| 45 | + "Target classes should be different for misclassification." |
21 | 46 |
|
22 | | - input_set = [single_image_input + torch.randn_like(single_image_input) for _ in range(100)] # Create a set of 10 noisy images |
23 | | - # convert input_set to a tensor |
| 47 | + # Create a noisy input set for black-box attack |
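| | + # randn_like seeds each of the 100 swarm particles with unit-variance Gaussian noise around the original image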
| 48 | + input_set = [single_image_input + torch.randn_like(single_image_input) for _ in range(100)] |
24 | 49 | input_set = torch.stack(input_set) |
25 | 50 |
|
26 | | - assert single_image_target != single_misclassification_target, "Target classes should be different for misclassification." |
27 | | - print(f"Target class for single image: {single_image_target}") |
28 | | - print(f"Target class for misclassification image: {single_misclassification_target} with confidence {torch.max(torch.softmax(model(single_misclassification_input.unsqueeze(0)), dim=1))}") |
29 | | - |
30 | | - # Initialize the Particle Swarm optimizer with the model and the input set |
31 | | - attacker = ParticleSwarm(model, |
32 | | - input_set, |
33 | | - single_misclassification_target, |
34 | | - num_iterations=30, |
35 | | - epsilon=0.8, |
36 | | - save_dir='results', |
37 | | - inertia_weight=0.8, |
38 | | - cognitive_weight=0.5, |
39 | | - social_weight=0.5, |
40 | | - momentum=0.9, |
41 | | - velocity_clamp=0.1) |
42 | | - final_perturbed_images = attacker.optimize() |
43 | | - import pdb; pdb.set_trace() |
44 | | - return final_perturbed_images |
45 | | - |
46 | | - |
47 | | -def main(): |
48 | | - # Load pre-trained model (ResNet18) |
49 | | - model = load_pretrained_model() |
50 | | - |
51 | | - # Load CIFAR-10 validation data (using the transformed dataset) |
52 | | - dataloader = load_data(batch_size=32) |
53 | | - |
54 | | - # Perform white-box attack using AdversarialTester |
55 | | - # print("Performing white-box adversarial attack...") |
56 | | - # adversarial_attack_whitebox(model, dataloader) |
57 | | - |
58 | | - # Perform black-box attack using Swarm |
| 51 | + print(f"Target class for original image: {single_image_target}") |
| 52 | + print(f"Target class for misclassified image: {single_misclassification_target}") |
| 53 | + |
| 54 | + # Initialize the Particle Swarm optimizer with the model and input set |
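| | + # epsilon bounds the perturbation; inertia/cognitive/social weights follow the standard PSO velocity update (assumed semantics of this library's parameters)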
| 55 | + attacker = ParticleSwarm( |
| 56 | + model, input_set, single_misclassification_target, num_iterations=30, |
| 57 | + epsilon=0.8, save_dir='results', inertia_weight=0.8, cognitive_weight=0.5, |
| 58 | + social_weight=0.5, momentum=0.9, velocity_clamp=0.1 |
| 59 | + ) |
| 60 | + attacker.optimize() |
| 61 | + |
| 62 | + # Generate adversarial dataset |
| 63 | + return get_adversarial_dataloader(attacker, model, single_misclassification_target, single_image_target) |
| 64 | + |
| 65 | + |
| 66 | +def get_adversarial_dataloader(attacker: ParticleSwarm, model: torch.nn.Module, target_class: int, original_class: int) -> DataLoader: |
| 67 | + """ |
| 68 | + Generates a DataLoader containing adversarially perturbed images. |
| 69 | + |
| 70 | + Args: |
| 71 | + attacker (ParticleSwarm): The ParticleSwarm instance after optimization. |
| 72 | + model (torch.nn.Module): The trained model used for evaluating adversarial examples. |
| 73 | + target_class (int): The target class for the attack. |
| 74 | + original_class (int): The original class of the image. |
| 75 | + |
| 76 | + Returns: |
| 77 | + DataLoader: A DataLoader yielding adversarial images with their target-class and original-class confidences.
| 78 | + """ |
| 79 | + print(f"Generating adversarial examples with target class {target_class} and original class {original_class}") |
| 80 | + |
| 81 | + images, target_confidence, original_confidence = [], [], [] |
| 82 | + |
| 83 | + for particle in attacker.particles: |
| 84 | + for position in particle.history: |
| 85 | + output = model(position) |
| 86 | + if torch.argmax(output) == target_class: |
| 87 | + images.append(position) |
| 88 | + target_confidence.append(torch.softmax(output, dim=1)[0, target_class])
| 89 | + original_confidence.append(torch.softmax(model(particle.original_data), dim=1)[0, original_class])
| 90 | + |
| 91 | + # Convert lists to tensors and return a TensorDataset |
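| | + # Each item is (adversarial image, target-class confidence, original-class confidence); DataLoader defaults to batch_size=1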
| 92 | + X_images = torch.stack(images) |
| 93 | + X_original_confidence = torch.stack(original_confidence) |
| 94 | + y = torch.stack(target_confidence) |
| 95 | + |
| 96 | + return DataLoader(TensorDataset(X_images, y, X_original_confidence)) |
| 97 | + |
| 98 | + |
| 99 | +def train(model: torch.nn.Module, dataloader: DataLoader, epochs: int = 10) -> torch.nn.Module: |
| 100 | + """ |
| 101 | + Trains the model for a specified number of epochs. |
| 102 | + |
| 103 | + Args: |
| 104 | + model (torch.nn.Module): The model to train. |
| 105 | + dataloader (DataLoader): The data loader for the training data. |
| 106 | + epochs (int, optional): Number of training epochs. Defaults to 10. |
| 107 | + |
| 108 | + Returns: |
| 109 | + torch.nn.Module: The trained model. |
| 110 | + """ |
| 111 | + loss_fn = torch.nn.CrossEntropyLoss() |
| 112 | + optimizer = torch.optim.Adam(model.parameters(), lr=0.001) |
| 113 | + |
| 114 | + for epoch in range(epochs): |
| 115 | + start_time = time.time() # Track time for each epoch |
| 116 | + print(f"\nEpoch {epoch+1}/{epochs}:") |
| 117 | + |
| 118 | + running_loss = 0.0 |
| 119 | + accuracy = 0 |
| 120 | + |
| 121 | + # Use tqdm for a progress bar |
| 122 | + with tqdm(dataloader, desc="Training", unit="batch") as pbar: |
| 123 | + for images, labels in pbar: |
| 124 | + optimizer.zero_grad() |
| 125 | + |
| 126 | + # Forward pass |
| 127 | + output = model(images) |
| 128 | + |
| 129 | + # Compute loss |
| 130 | + loss_val = loss_fn(output, labels) |
| 131 | + |
| 132 | + # Backward pass and optimization |
| 133 | + loss_val.backward() |
| 134 | + optimizer.step() |
| 135 | + |
| 136 | + running_loss += loss_val.item() |
| 137 | + accuracy += (output.argmax(dim=1) == labels).float().mean().item() |
| 138 | + |
| 139 | + # Update progress bar description |
| 140 | + pbar.set_postfix(loss=running_loss / (pbar.n + 1), accuracy=accuracy / (pbar.n + 1)) |
| 141 | + |
| 142 | + # Print average loss and accuracy for the epoch |
| 143 | + epoch_loss = running_loss / len(dataloader) |
| 144 | + elapsed_time = time.time() - start_time |
| 145 | + print(f"Epoch {epoch+1} completed in {elapsed_time:.2f}s, Average Loss: {epoch_loss:.4f}, Accuracy: {accuracy / len(dataloader):.4f}") |
| 146 | + |
| 147 | + return model |
| 148 | + |
| 149 | + |
| 150 | +def main() -> None: |
| 151 | + """ |
| 152 | + Main function to execute the adversarial attack workflow. |
| 153 | + """ |
| 154 | + # Load pre-trained model (MNIST model) |
| 155 | + model = load_MNIST_model() |
| 156 | + |
| 157 | + # Load MNIST dataset (train and test loaders) |
| 158 | + train_loader, test_loader = load_data() |
| 159 | + |
| 160 | + # Train the model |
| 161 | + model = train(model, train_loader, epochs=3) |
| 162 | + |
| 163 | + # Perform black-box attack using Particle Swarm optimization |
59 | 164 | print("Performing black-box adversarial attack...") |
60 | | - adversarial_attack_blackbox(model, dataloader) |
| 165 | + final_dataloader = adversarial_attack_blackbox(model, test_loader) |
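| | + print(f"Collected {len(final_dataloader.dataset)} adversarial examples.")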
| 166 | + |
| 167 | + |
61 | 168 |
|
62 | 169 | if __name__ == "__main__": |
63 | 170 | main() |