Skip to content

Commit 9809e51

Browse files
committed
Conversion to tensorflow/keras
1 parent 7f59c03 commit 9809e51

File tree

3 files changed

+294
-53
lines changed

3 files changed

+294
-53
lines changed
Lines changed: 23 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
1-
import torch
1+
import tensorflow as tf
22
import numpy as np
3-
from torch import nn
4-
53

64
class BirdParticle:
75
"""
@@ -11,33 +9,33 @@ class BirdParticle:
119
fitness evaluation, and the updates to its velocity and position based on the PSO algorithm.
1210
"""
1311

14-
def __init__(self, model: nn.Module, input_data: torch.Tensor, target_class: int, epsilon: float,
15-
velocity: torch.Tensor = None, inertia_weight: float = 0.5,
12+
def __init__(self, model: tf.keras.Model, input_data: tf.Tensor, target_class: int, epsilon: float,
13+
velocity: tf.Tensor = None, inertia_weight: float = 0.5,
1614
cognitive_weight: float = 1.0, social_weight: float = 1.0, momentum: float = 0.9,
1715
velocity_clamp: float = 0.1):
1816
"""
1917
Initialize a particle in the PSO algorithm.
2018
2119
Args:
22-
model (nn.Module): The model to attack.
23-
input_data (torch.Tensor): The input data (image) to attack.
20+
model (tf.keras.Model): The model to attack.
21+
input_data (tf.Tensor): The input data (image) to attack.
2422
target_class (int): The target class for misclassification.
2523
epsilon (float): The perturbation bound (maximum amount the image can be altered).
26-
velocity (torch.Tensor, optional): The initial velocity for the particle's movement. Defaults to zero velocity if not provided.
24+
velocity (tf.Tensor, optional): The initial velocity for the particle's movement. Defaults to zero velocity if not provided.
2725
inertia_weight (float): The inertia weight for the velocity update. Default is 0.5.
2826
cognitive_weight (float): The cognitive weight for the velocity update. Default is 1.0.
2927
social_weight (float): The social weight for the velocity update. Default is 1.0.
3028
momentum (float): The momentum for the velocity update. Default is 0.9.
3129
velocity_clamp (float): The velocity clamp for limiting the maximum velocity. Default is 0.1.
3230
"""
3331
self.model = model
34-
self.original_data = input_data.clone().detach()
32+
self.original_data = tf.identity(input_data) # Clone the input data
3533
self.target_class = target_class
3634
self.epsilon = epsilon
37-
self.best_position = input_data.clone().detach()
35+
self.best_position = tf.identity(input_data) # Clone the input data
3836
self.best_score = -np.inf
39-
self.position = input_data.clone().detach()
40-
self.velocity = velocity if velocity is not None else torch.zeros_like(input_data)
37+
self.position = tf.identity(input_data) # Clone the input data
38+
self.velocity = velocity if velocity is not None else tf.zeros_like(input_data)
4139
self.history = []
4240

4341
# Class attributes
@@ -56,38 +54,37 @@ def fitness(self) -> float:
5654
Returns:
5755
float: Fitness score for this particle (higher is better).
5856
"""
59-
with torch.no_grad():
60-
output = self.model(self.position)
61-
probabilities = torch.softmax(output, dim=1) # Get probabilities for each class
62-
target_prob = probabilities[:, self.target_class] # Target class probability
63-
64-
return target_prob.item() # Return the target class probability as fitness score
57+
output = self.model(self.position) # Add batch dimension and pass through the model
58+
probabilities = tf.nn.softmax(output, axis=1) # Get probabilities for each class
59+
target_prob = probabilities[:, self.target_class] # Target class probability
60+
61+
return target_prob.numpy().item() # Return the target class probability as fitness score
6562

66-
def update_velocity(self, global_best_position: torch.Tensor) -> None:
63+
def update_velocity(self, global_best_position: tf.Tensor) -> None:
6764
"""
6865
Update the velocity of the particle based on the PSO update rule.
6966
7067
Args:
71-
global_best_position (torch.Tensor): The global best position found by the swarm.
68+
global_best_position (tf.Tensor): The global best position found by the swarm.
7269
"""
7370
inertia = self.inertia_weight * self.velocity
74-
cognitive = self.cognitive_weight * torch.rand_like(self.position) * (self.best_position - self.position)
75-
social = self.social_weight * torch.rand_like(self.position) * (global_best_position - self.position)
71+
cognitive = self.cognitive_weight * tf.random.uniform(self.position.shape) * (self.best_position - self.position)
72+
social = self.social_weight * tf.random.uniform(self.position.shape) * (global_best_position - self.position)
7673

7774
self.velocity = inertia + cognitive + social # Update velocity based on PSO formula
7875

7976
# Apply momentum and velocity clamping
8077
self.velocity = self.velocity * self.momentum # Apply momentum
81-
self.velocity = torch.clamp(self.velocity, -self.velocity_clamp, self.velocity_clamp) # Apply velocity clamp
78+
self.velocity = tf.clip_by_value(self.velocity, -self.velocity_clamp, self.velocity_clamp) # Apply velocity clamp
8279

8380
def update_position(self) -> None:
8481
"""
8582
Update the position of the particle based on the updated velocity.
8683
8784
Ensures that the position stays within the valid input range [0, 1] (normalized pixel values).
8885
"""
89-
self.position = torch.clamp(self.position + self.velocity, 0, 1) # Ensure position stays within bounds
90-
self.history.append(self.position.clone().detach())
86+
self.position = tf.clip_by_value(self.position + self.velocity, 0.0, 1.0) # Ensure position stays within bounds
87+
self.history.append(tf.identity(self.position)) # Store the position history
9188

9289
def evaluate(self) -> None:
9390
"""
@@ -99,4 +96,4 @@ def evaluate(self) -> None:
9996
score = self.fitness() # Get the current fitness score based on the perturbation
10097
if score > self.best_score: # If score is better than the personal best, update the best position
10198
self.best_score = score
102-
self.best_position = self.position.clone().detach()
99+
self.best_position = tf.identity(self.position) # Clone the current position

Adversarial_Observation/Swarm.py

Lines changed: 37 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
import os
22
import logging
3-
import torch
4-
from torch import nn
5-
from torchvision import utils
63
from typing import List
74
from Adversarial_Observation.BirdParticle import BirdParticle
85

6+
import torch
7+
from torchvision import utils
8+
9+
import tensorflow as tf
10+
import numpy as np
911

1012
class ParticleSwarm:
1113
"""
@@ -15,16 +17,16 @@ class ParticleSwarm:
1517
to misclassify it into the target class.
1618
"""
1719

18-
def __init__(self, model: nn.Module, input_set: torch.Tensor, target_class: int,
20+
def __init__(self, model: tf.keras.Model, input_set: np.ndarray, target_class: int,
1921
num_iterations: int = 20, epsilon: float = 0.8, save_dir: str = 'results',
2022
inertia_weight: float = 0.5, cognitive_weight: float = .5, social_weight: float = .5,
2123
momentum: float = 0.9, velocity_clamp: float = 0.1):
2224
"""
2325
Initialize the Particle Swarm Optimization (PSO) for adversarial attacks.
2426
2527
Args:
26-
model (nn.Module): The model to attack.
27-
input_set (torch.Tensor): The batch of input images to attack.
28+
model (tf.keras.Model): The model to attack.
29+
input_set (np.ndarray): The batch of input images to attack as a NumPy array.
2830
target_class (int): The target class for misclassification.
2931
num_iterations (int): The number of optimization iterations.
3032
epsilon (float): The perturbation bound.
@@ -36,20 +38,20 @@ def __init__(self, model: nn.Module, input_set: torch.Tensor, target_class: int,
3638
velocity_clamp (float): The velocity clamp to limit the velocity.
3739
"""
3840
self.model = model
39-
self.input_set = input_set # The batch of input images
41+
self.input_set = tf.convert_to_tensor(input_set, dtype=tf.float32) # Convert NumPy array to TensorFlow tensor
4042
self.target_class = target_class # The target class index
4143
self.num_iterations = num_iterations
4244
self.epsilon = epsilon # Perturbation bound
4345
self.save_dir = save_dir # Directory to save perturbed images
4446

4547
self.particles: List[BirdParticle] = [
46-
BirdParticle(model, input_set[i:i + 1], target_class, epsilon,
48+
BirdParticle(model, self.input_set[i:i + 1], target_class, epsilon,
4749
inertia_weight=inertia_weight, cognitive_weight=cognitive_weight,
4850
social_weight=social_weight, momentum=momentum, velocity_clamp=velocity_clamp)
4951
for i in range(len(input_set))
5052
]
5153

52-
self.global_best_position = torch.zeros_like(self.input_set[0]) # Global best position
54+
self.global_best_position = tf.zeros_like(self.input_set[0]) # Global best position
5355
self.global_best_score = -float('inf') # Initialize with a very low score
5456

5557
self.fitness_history: List[float] = [] # History of fitness scores to track progress
@@ -101,32 +103,35 @@ def log_progress(self, iteration: int):
101103
self.logger.info(f"\n{'-'*60}")
102104
self.logger.info(f"Iteration {iteration + 1}/{self.num_iterations}")
103105
self.logger.info(f"{'='*60}")
104-
106+
105107
# Table header
106108
header = f"{'Particle':<10}{'Original Pred':<15}{'Perturbed Pred':<18}{'Orig Target Prob':<20}" \
107109
f"{'Pert Target Prob':<20}{'Personal Best':<20}{'Global Best':<20}"
108110
self.logger.info(header)
109111
self.logger.info(f"{'-'*60}")
110-
112+
111113
# Log particle information
112114
for i, particle in enumerate(self.particles):
113-
with torch.no_grad():
114-
original_output = self.model(particle.original_data)
115-
perturbed_output = self.model(particle.position)
116-
117-
original_pred = torch.argmax(original_output, dim=1).item()
118-
perturbed_pred = torch.argmax(perturbed_output, dim=1).item()
119-
120-
original_probs = torch.softmax(original_output, dim=1)
121-
perturbed_probs = torch.softmax(perturbed_output, dim=1)
122-
123-
original_prob_target = original_probs[0, self.target_class].item()
124-
perturbed_prob_target = perturbed_probs[0, self.target_class].item()
125-
115+
# Get original and perturbed outputs
116+
original_output = self.model(particle.original_data) # Pass through the model
117+
perturbed_output = self.model(particle.position) # Pass through the model
118+
119+
# Get predicted classes
120+
original_pred = tf.argmax(original_output, axis=1).numpy().item()
121+
perturbed_pred = tf.argmax(perturbed_output, axis=1).numpy().item()
122+
123+
# Get softmax probabilities
124+
original_probs = tf.nn.softmax(original_output, axis=1)
125+
perturbed_probs = tf.nn.softmax(perturbed_output, axis=1)
126+
127+
# Get target class probabilities
128+
original_prob_target = original_probs[0, self.target_class].numpy().item()
129+
perturbed_prob_target = perturbed_probs[0, self.target_class].numpy().item()
130+
126131
# Log each particle's data in a formatted row
127132
self.logger.info(f"{i+1:<10}{original_pred:<15}{perturbed_pred:<18}{original_prob_target:<20.4f}"
128133
f"{perturbed_prob_target:<20.4f}{particle.best_score:<20.4f}{self.global_best_score:<20.4f}")
129-
134+
130135
self.logger.info(f"{'='*60}")
131136

132137
def save_images(self, iteration: int):
@@ -140,7 +145,12 @@ def save_images(self, iteration: int):
140145
os.makedirs(iteration_folder, exist_ok=True)
141146

142147
for i, particle in enumerate(self.particles):
143-
utils.save_image(particle.position, os.path.join(iteration_folder, f"perturbed_image_{i + 1}.png"))
148+
# Convert TensorFlow tensor to NumPy array
149+
position_numpy = particle.position.numpy()
150+
# Remove extra batch dimension (if it exists)
151+
position_numpy = np.squeeze(position_numpy) # Now shape is (28, 28)
152+
position_numpy = np.expand_dims(position_numpy, axis=-1) # Shape becomes (28, 28, 1)
153+
tf.keras.preprocessing.image.save_img(os.path.join(iteration_folder, f"perturbed_image_{i + 1}.png"), position_numpy)
144154

145155
def optimize(self):
146156
"""
@@ -158,7 +168,7 @@ def optimize(self):
158168
best_particle = max(self.particles, key=lambda p: p.best_score)
159169
if best_particle.best_score > self.global_best_score:
160170
self.global_best_score = best_particle.best_score
161-
self.global_best_position = best_particle.best_position.clone().detach()
171+
self.global_best_position = tf.identity(best_particle.best_position)
162172

163173
self.save_images(iteration)
164174

0 commit comments

Comments
 (0)