1- import numpy as np
21import torch
3- import torch .nn .functional as F
4- import matplotlib .pyplot as plt
52import logging
3+ import os
4+ from datetime import datetime
5+ from torch .nn import Softmax
6+ from .utils import fgsm_attack , pgd_attack , compute_success_rate , log_metrics , visualize_adversarial_examples
7+ from .utils import seed_everything
68
# Configure the root logger once at import time so INFO-level messages are emitted.
logging.basicConfig(
    level=logging.INFO,
)
9-
def fgsm_attack(input_batch_data: torch.Tensor, model: torch.nn.Module, input_shape: tuple, epsilon: float) -> torch.Tensor:
    """
    Apply the FGSM attack to input images given a pre-trained PyTorch model.

    Each image in the batch is attacked independently: it is reshaped to
    ``input_shape``, passed through ``model``, and the model's own top
    prediction is used as the label for a cross-entropy loss. The image is
    then shifted by ``epsilon`` along the sign of the loss gradient and
    clamped to the range [0, 1].

    Args:
        input_batch_data (torch.Tensor): Batch of input images, iterated along dim 0.
        model (torch.nn.Module): Pre-trained PyTorch model to be attacked.
            It is switched to eval mode.
        input_shape (tuple): Shape each single image is reshaped to before it
            is passed to the model.
        epsilon (float): Magnitude of the perturbation for the attack.

    Returns:
        torch.Tensor: Adversarial images generated by the FGSM attack, detached
        from the autograd graph and shaped like ``input_batch_data``.
    """
    model.eval()

    # Run on the device the model already lives on; moving only the input to
    # CUDA (as before) fails whenever the model was left on the CPU.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else input_batch_data.device
    clean_batch = input_batch_data.detach().to(device)

    # Nothing to attack: torch.stack below cannot handle an empty list.
    if clean_batch.numel() == 0:
        return clean_batch.clone()

    adversarial_images = []
    for clean_img in clean_batch:
        # A slice of the batch is not a leaf tensor, so its .grad is never
        # populated. Build an explicit leaf and ask autograd for its gradient.
        model_input = clean_img.reshape(input_shape).clone().requires_grad_(True)

        preds = model(model_input)
        target = torch.argmax(preds)
        loss = F.cross_entropy(preds, target.unsqueeze(0))

        # autograd.grad only computes the input gradient and leaves the
        # model parameters' .grad buffers untouched.
        (input_grad,) = torch.autograd.grad(loss, model_input)

        step = epsilon * input_grad.sign().reshape(clean_img.shape)
        adversarial_images.append(torch.clamp(clean_img + step, 0, 1))

    return torch.stack(adversarial_images)
42-
def compute_gradients(model, img, target_class):
    """
    Return the gradient of one class score with respect to the input.

    The model is evaluated on ``img`` and the score at row 0, column
    ``target_class`` of its output is differentiated with respect to ``img``.
    ``img`` must already take part in autograd (``requires_grad=True``) for
    ``torch.autograd.grad`` to succeed.
    """
    class_score = model(img)[0, target_class]
    (grad_wrt_img,) = torch.autograd.grad(class_score, img)
    return grad_wrt_img
47-
def generate_adversarial_examples(input_batch_data, model, method='fgsm', **kwargs):
    """
    Dispatch to the attack implementation selected by ``method``.

    Args:
        input_batch_data: Batch of inputs forwarded to the attack function.
        model: Model forwarded to the attack function.
        method (str): Name of the attack. Only ``'fgsm'`` is implemented.
        **kwargs: Extra keyword arguments forwarded unchanged to the attack
            function (for ``'fgsm'`` these go to ``fgsm_attack``).

    Returns:
        Whatever the selected attack function returns.

    Raises:
        ValueError: If ``method`` does not name an implemented attack.
    """
    if method == 'fgsm':
        return fgsm_attack(input_batch_data, model, **kwargs)
    # Implement other attack methods as needed
    # Fail loudly instead of silently handing back None for an unknown name.
    raise ValueError(f"Unsupported attack method: {method!r}")
52-
def visualize_adversarial_examples(original, adversarial):
    """
    Placeholder for visualizing original vs adversarial images.

    No visualization is implemented yet: both arguments are accepted and
    ignored, and the function returns ``None``.
    """
    return None
56-
def log_metrics(success_rate, average_perturbation):
    """Log the success rate and average perturbation at INFO level on the root logger."""
    message = f'Success Rate: {success_rate} , Average Perturbation: {average_perturbation} '
    logging.info(message)
59-
class Config:
    """Plain holder for the attack settings: perturbation size and attack name."""

    def __init__(self, epsilon=0.1, attack_method='fgsm'):
        # Store both settings as-is; no validation is performed.
        self.epsilon, self.attack_method = epsilon, attack_method
class AdversarialTester:
    """
    Run an adversarial attack against a model and record the outcome.

    On construction the tester seeds via ``seed_everything``, moves the model
    to the chosen device, puts it in eval mode, creates ``save_dir`` and sets
    up file logging. ``test_attack`` then generates adversarial images for a
    batch, hands them to the visualization helper, and logs/saves the success
    rate and mean absolute perturbation.
    """

    def __init__(self, model: torch.nn.Module, epsilon: float = 0.1, attack_method: str = 'fgsm', alpha: float = 0.01,
                 num_steps: int = 40, device=None, save_dir: str = './results', seed: int = 42):
        """
        Args:
            model: Model to attack; moved to ``device`` and set to eval mode.
            epsilon: Perturbation budget passed to the attack function.
            attack_method: ``'fgsm'`` or ``'pgd'``.
            alpha: Step size passed to the PGD attack.
            num_steps: Number of steps passed to the PGD attack.
            device: Target device; defaults to CUDA when available, else CPU.
            save_dir: Directory for the log, images and metrics file.
            seed: Value passed to ``seed_everything``.
        """
        seed_everything(seed)
        self.model = model
        self.epsilon = epsilon
        self.attack_method = attack_method
        self.alpha = alpha
        self.num_steps = num_steps
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.save_dir = save_dir

        # Create save directory if it doesn't exist
        os.makedirs(self.save_dir, exist_ok=True)
        self.model.to(self.device)
        self.model.eval()

        self._setup_logging()

    def _setup_logging(self):
        """Point root logging at a timestamped file in ``save_dir`` and log the run settings."""
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        log_file = os.path.join(self.save_dir, f"attack_log_{timestamp}.log")
        # NOTE: basicConfig is a no-op when the root logger already has
        # handlers, so only the first configuration in a process takes effect.
        logging.basicConfig(filename=log_file, level=logging.DEBUG)
        logging.info(f"Started adversarial testing at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} ")
        logging.info(f"Using model: {self.model.__class__.__name__} ")
        logging.info(f"Attack Method: {self.attack_method} , Epsilon: {self.epsilon} , Alpha: {self.alpha} , Steps: {self.num_steps} ")

    def test_attack(self, input_batch_data: torch.Tensor):
        """Attack ``input_batch_data``, then save the images and the resulting metrics."""
        input_batch_data = input_batch_data.to(self.device)
        adversarial_images = self._generate_adversarial_images(input_batch_data)

        # Save and log images
        self._save_images(input_batch_data, adversarial_images)
        self._compute_and_log_metrics(input_batch_data, adversarial_images)

    def _generate_adversarial_images(self, input_batch_data: torch.Tensor):
        """
        Run the configured attack and return its result.

        Raises:
            ValueError: If ``attack_method`` is neither ``'fgsm'`` nor ``'pgd'``.
        """
        logging.info(f"Starting attack with method: {self.attack_method} ")
        if self.attack_method == 'fgsm':
            return fgsm_attack(input_batch_data, self.model, self.epsilon, self.device)
        if self.attack_method == 'pgd':
            return pgd_attack(input_batch_data, self.model, self.epsilon, self.alpha, self.num_steps, self.device)
        raise ValueError(f"Unsupported attack method: {self.attack_method} ")

    def _save_images(self, original_images: torch.Tensor, adversarial_images: torch.Tensor):
        """Hand each original/adversarial pair to the visualization helper with per-index paths."""
        for i in range(original_images.size(0)):
            original_image_path = os.path.join(self.save_dir, f"original_{i}.png")
            adversarial_image_path = os.path.join(self.save_dir, f"adversarial_{i}.png")
            # Pass only sample i (as a batch of one, so the tensor rank the
            # helper receives is unchanged) instead of the whole batch, which
            # would write the same content under every index.
            # NOTE(review): confirm the helper's expected input against .utils.
            visualize_adversarial_examples(
                original_images[i:i + 1],
                adversarial_images[i:i + 1],
                original_image_path,
                adversarial_image_path,
            )

    def _compute_and_log_metrics(self, original_images: torch.Tensor, adversarial_images: torch.Tensor):
        """Compute the success rate and mean absolute perturbation, then log and save them."""
        # Predictions are only compared, never differentiated, so skip
        # building an autograd graph for the two forward passes.
        with torch.no_grad():
            original_predictions = torch.argmax(self.model(original_images), dim=1)
            adversarial_predictions = torch.argmax(self.model(adversarial_images), dim=1)
            average_perturbation = torch.mean(torch.abs(adversarial_images - original_images)).item()

        success_rate = compute_success_rate(original_predictions, adversarial_predictions)

        log_metrics(success_rate, average_perturbation)
        self._save_metrics(success_rate, average_perturbation)

        logging.info(f"Success Rate: {success_rate:.4f} , Average Perturbation: {average_perturbation:.4f} ")

    def _save_metrics(self, success_rate: float, avg_perturbation: float):
        """
        Save the metrics (success rate and average perturbation) to a file.

        One line is appended to ``attack_metrics.txt`` in ``save_dir`` per call.
        """
        metrics_file = os.path.join(self.save_dir, "attack_metrics.txt")
        with open(metrics_file, 'a', encoding='utf-8') as f:
            # End the record with exactly one newline so the next record
            # starts at column 0.
            f.write(f"Success Rate: {success_rate:.4f} , Average Perturbation: {avg_perturbation:.4f}\n")
0 commit comments