11import os
22import logging
3- import torch
4- from torch import nn
5- from torchvision import utils
63from typing import List
74from Adversarial_Observation .BirdParticle import BirdParticle
85
6+ import torch
7+ from torchvision import utils
8+
9+ import tensorflow as tf
10+ import numpy as np
911
1012class ParticleSwarm :
1113 """
@@ -15,16 +17,16 @@ class ParticleSwarm:
1517 to misclassify it into the target class.
1618 """
1719
18- def __init__ (self , model : nn . Module , input_set : torch . Tensor , target_class : int ,
20+ def __init__ (self , model : tf . keras . Model , input_set : np . ndarray , target_class : int ,
1921 num_iterations : int = 20 , epsilon : float = 0.8 , save_dir : str = 'results' ,
2022 inertia_weight : float = 0.5 , cognitive_weight : float = .5 , social_weight : float = .5 ,
2123 momentum : float = 0.9 , velocity_clamp : float = 0.1 ):
2224 """
2325 Initialize the Particle Swarm Optimization (PSO) for adversarial attacks.
2426
2527 Args:
26- model (nn.Module ): The model to attack.
27- input_set (torch.Tensor ): The batch of input images to attack.
28+ model (tf.keras.Model ): The model to attack.
29+ input_set (np.ndarray ): The batch of input images to attack as a NumPy array .
2830 target_class (int): The target class for misclassification.
2931 num_iterations (int): The number of optimization iterations.
3032 epsilon (float): The perturbation bound.
@@ -36,20 +38,20 @@ def __init__(self, model: nn.Module, input_set: torch.Tensor, target_class: int,
3638 velocity_clamp (float): The velocity clamp to limit the velocity.
3739 """
3840 self .model = model
39- self .input_set = input_set # The batch of input images
41+ self .input_set = tf . convert_to_tensor ( input_set , dtype = tf . float32 ) # Convert NumPy array to TensorFlow tensor
4042 self .target_class = target_class # The target class index
4143 self .num_iterations = num_iterations
4244 self .epsilon = epsilon # Perturbation bound
4345 self .save_dir = save_dir # Directory to save perturbed images
4446
4547 self .particles : List [BirdParticle ] = [
46- BirdParticle (model , input_set [i :i + 1 ], target_class , epsilon ,
48+ BirdParticle (model , self . input_set [i :i + 1 ], target_class , epsilon ,
4749 inertia_weight = inertia_weight , cognitive_weight = cognitive_weight ,
4850 social_weight = social_weight , momentum = momentum , velocity_clamp = velocity_clamp )
4951 for i in range (len (input_set ))
5052 ]
5153
52- self .global_best_position = torch .zeros_like (self .input_set [0 ]) # Global best position
54+ self .global_best_position = tf .zeros_like (self .input_set [0 ]) # Global best position
5355 self .global_best_score = - float ('inf' ) # Initialize with a very low score
5456
5557 self .fitness_history : List [float ] = [] # History of fitness scores to track progress
@@ -101,32 +103,35 @@ def log_progress(self, iteration: int):
101103 self .logger .info (f"\n { '-' * 60 } " )
102104 self .logger .info (f"Iteration { iteration + 1 } /{ self .num_iterations } " )
103105 self .logger .info (f"{ '=' * 60 } " )
104-
106+
105107 # Table header
106108 header = f"{ 'Particle' :<10} { 'Original Pred' :<15} { 'Perturbed Pred' :<18} { 'Orig Target Prob' :<20} " \
107109 f"{ 'Pert Target Prob' :<20} { 'Personal Best' :<20} { 'Global Best' :<20} "
108110 self .logger .info (header )
109111 self .logger .info (f"{ '-' * 60 } " )
110-
112+
111113 # Log particle information
112114 for i , particle in enumerate (self .particles ):
113- with torch .no_grad ():
114- original_output = self .model (particle .original_data )
115- perturbed_output = self .model (particle .position )
116-
117- original_pred = torch .argmax (original_output , dim = 1 ).item ()
118- perturbed_pred = torch .argmax (perturbed_output , dim = 1 ).item ()
119-
120- original_probs = torch .softmax (original_output , dim = 1 )
121- perturbed_probs = torch .softmax (perturbed_output , dim = 1 )
122-
123- original_prob_target = original_probs [0 , self .target_class ].item ()
124- perturbed_prob_target = perturbed_probs [0 , self .target_class ].item ()
125-
115+ # Get original and perturbed outputs
116+ original_output = self .model (particle .original_data ) # Pass through the model
117+ perturbed_output = self .model (particle .position ) # Pass through the model
118+
119+ # Get predicted classes
120+ original_pred = tf .argmax (original_output , axis = 1 ).numpy ().item ()
121+ perturbed_pred = tf .argmax (perturbed_output , axis = 1 ).numpy ().item ()
122+
123+ # Get softmax probabilities
124+ original_probs = tf .nn .softmax (original_output , axis = 1 )
125+ perturbed_probs = tf .nn .softmax (perturbed_output , axis = 1 )
126+
127+ # Get target class probabilities
128+ original_prob_target = original_probs [0 , self .target_class ].numpy ().item ()
129+ perturbed_prob_target = perturbed_probs [0 , self .target_class ].numpy ().item ()
130+
126131 # Log each particle's data in a formatted row
127132 self .logger .info (f"{ i + 1 :<10} { original_pred :<15} { perturbed_pred :<18} { original_prob_target :<20.4f} "
128133 f"{ perturbed_prob_target :<20.4f} { particle .best_score :<20.4f} { self .global_best_score :<20.4f} " )
129-
134+
130135 self .logger .info (f"{ '=' * 60 } " )
131136
132137 def save_images (self , iteration : int ):
@@ -140,7 +145,12 @@ def save_images(self, iteration: int):
140145 os .makedirs (iteration_folder , exist_ok = True )
141146
142147 for i , particle in enumerate (self .particles ):
143- utils .save_image (particle .position , os .path .join (iteration_folder , f"perturbed_image_{ i + 1 } .png" ))
148+ # Convert TensorFlow tensor to NumPy array
149+ position_numpy = particle .position .numpy ()
150+ # Remove extra batch dimension (if it exists)
151+ position_numpy = np .squeeze (position_numpy ) # Now shape is (28, 28)
152+ position_numpy = np .expand_dims (position_numpy , axis = - 1 ) # Shape becomes (28, 28, 1)
153+ tf .keras .preprocessing .image .save_img (os .path .join (iteration_folder , f"perturbed_image_{ i + 1 } .png" ), position_numpy )
144154
145155 def optimize (self ):
146156 """
@@ -158,7 +168,7 @@ def optimize(self):
158168 best_particle = max (self .particles , key = lambda p : p .best_score )
159169 if best_particle .best_score > self .global_best_score :
160170 self .global_best_score = best_particle .best_score
161- self .global_best_position = best_particle . best_position . clone (). detach ( )
171+ self .global_best_position = tf . identity ( best_particle . best_position )
162172
163173 self .save_images (iteration )
164174
0 commit comments