
Commit de191be

DynamicBackdoorGAN implementation with workflow fix
Signed-off-by: Prachi Panwar <[email protected]>
1 parent 261b541 commit de191be

10 files changed: +321 −20 lines changed

.github/workflows/ci-huggingface.yml

Lines changed: 2 additions & 1 deletion
@@ -51,13 +51,14 @@ jobs:
          sudo apt-get update
          sudo apt-get -y -q install ffmpeg libavcodec-extra
          python -m pip install --upgrade pip setuptools wheel
-         pip install -q -r <(sed '/^tensorflow/d;/^keras/d;/^torch/d;/^torchvision/d;/^torchaudio/d;/^transformers/d' requirements_test.txt)
+         pip install -q -r <(sed '/^tensorflow/d;/^keras/d;/^torch/d;/^torchvision/d;/^torchaudio/d;/^transformers/d;/^safetensors/d' requirements_test.txt)
          pip install tensorflow==2.18.1
          pip install keras==3.10.0
          pip install torch==${{ matrix.torch }} --index-url https://download.pytorch.org/whl/cpu
          pip install torchvision==${{ matrix.torchvision }} --index-url https://download.pytorch.org/whl/cpu
          pip install torchaudio==${{ matrix.torchaudio }} --index-url https://download.pytorch.org/whl/cpu
          pip install transformers==${{ matrix.transformers }}
+         pip install safetensors==0.5.3
          pip list

      - name: Cache CIFAR-10 dataset

.github/workflows/ci-legacy.yml

Lines changed: 3 additions & 3 deletions
@@ -34,9 +34,9 @@ jobs:
            python: '3.10'
            tensorflow: 2.18.1
            keras: 3.10.0
-           torch: 2.7.0
-           torchvision: 0.22.0
-           torchaudio: 2.7.0
+           torch: 2.8.0
+           torchvision: 0.23.0
+           torchaudio: 2.8.0
            scikit-learn: 1.6.1

    name: Run ${{ matrix.module }} ${{ matrix.name }} Tests

.github/workflows/ci-pytorch.yml

Lines changed: 6 additions & 6 deletions
@@ -28,18 +28,18 @@ jobs:
      fail-fast: false
      matrix:
        include:
-         - name: PyTorch 2.6.0 (Python 3.10)
-           framework: pytorch
-           python: '3.10'
-           torch: 2.6.0
-           torchvision: 0.21.0
-           torchaudio: 2.6.0
          - name: PyTorch 2.7.1 (Python 3.10)
            framework: pytorch
            python: '3.10'
            torch: 2.7.1
            torchvision: 0.22.1
            torchaudio: 2.7.1
+         - name: PyTorch 2.8.0 (Python 3.10)
+           framework: pytorch
+           python: '3.10'
+           torch: 2.8.0
+           torchvision: 0.23.0
+           torchaudio: 2.8.0

    name: ${{ matrix.name }}
    steps:

Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
name: CI TensorFlow v1

on:
  pull_request:
    branches: [ main ]
  workflow_dispatch:

jobs:
  sanity:
    runs-on: ubuntu-latest
    steps:
      - name: Check out repo
        uses: actions/checkout@v4
      - name: Say hello
        run: echo "Workflow is wired up and running."

art/attacks/poisoning/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -19,3 +19,5 @@
 from art.attacks.poisoning.hidden_trigger_backdoor.hidden_trigger_backdoor_pytorch import HiddenTriggerBackdoorPyTorch
 from art.attacks.poisoning.hidden_trigger_backdoor.hidden_trigger_backdoor_keras import HiddenTriggerBackdoorKeras
 from art.attacks.poisoning.sleeper_agent_attack import SleeperAgentAttack
+from art.attacks.poisoning.dynamic_backdoor_gan import DynamicBackdoorGAN
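With this export, the new attack is importable alongside the existing poisoning attacks. A minimal sanity check (illustrative, not part of the commit):

    from art.attacks.poisoning import DynamicBackdoorGAN, PoisoningAttackBackdoor

    # DynamicBackdoorGAN subclasses PoisoningAttackBackdoor, so it should slot into
    # code paths that already accept the backdoor poisoning attack interface.
    assert issubclass(DynamicBackdoorGAN, PoisoningAttackBackdoor)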
Lines changed: 77 additions & 0 deletions
@@ -0,0 +1,77 @@
# Imports
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import Subset
from torchvision import datasets, transforms, models
from art.estimators.classification import PyTorchClassifier
from art.utils import to_categorical
from art.attacks.poisoning import PoisoningAttackBackdoor


# Trigger generator: a small CNN that learns to generate input-specific triggers
class TriggerGenerator(nn.Module):
    def __init__(self, input_channels=3):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(input_channels, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, input_channels, kernel_size=3, padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        return self.net(x)


# Custom poisoning attack: DynamicBackdoorGAN. Defines how to poison data using the GAN trigger generator.
class DynamicBackdoorGAN(PoisoningAttackBackdoor):
    def __init__(self, generator, target_label, backdoor_rate, classifier, epsilon=0.5):
        super().__init__(perturbation=lambda x: x)
        self.classifier = classifier
        self.generator = generator.to(classifier.device)
        self.target_label = target_label
        self.backdoor_rate = backdoor_rate
        self.epsilon = epsilon

    # Add the trigger to a batch of images
    def apply_trigger(self, images):
        self.generator.eval()
        with torch.no_grad():
            images = nn.functional.interpolate(images, size=(32, 32), mode='bilinear')  # resize images to a uniform dimension
            triggers = self.generator(images.to(self.classifier.device))  # generate dynamic, input-specific triggers with the trained TriggerGenerator
            poisoned = (images.to(self.classifier.device) + self.epsilon * triggers).clamp(0, 1)  # clamp pixel values to the valid [0, 1] range
        return poisoned

    # Poison the training data by injecting dynamic triggers and changing labels
    def poison(self, x, y):
        # Convert raw image data (x) to float tensors and one-hot labels (y) to class indices, as required below
        x_tensor = torch.tensor(x).float()
        y_tensor = torch.tensor(np.argmax(y, axis=1))
        # Total number of samples and how many of them to poison (poison ratio = backdoor_rate)
        batch_size = x_tensor.shape[0]
        n_poison = int(self.backdoor_rate * batch_size)
        # Apply the learned trigger to the first n_poison samples
        poisoned = self.apply_trigger(x_tensor[:n_poison])
        # The remaining samples stay clean
        clean = x_tensor[n_poison:].to(self.classifier.device)
        # Combine poisoned and clean samples into a single batch
        poisoned_images = torch.cat([poisoned, clean], dim=0).cpu().numpy()
        # Relabel the poisoned samples with the attacker's target class
        new_labels = y_tensor.clone()
        new_labels[:n_poison] = self.target_label  # set the poisoned labels to the desired misclassification
        # Convert all labels back to one-hot encoding (required by ART classifiers)
        new_labels = to_categorical(new_labels.numpy(), nb_classes=self.classifier.nb_classes)
        return poisoned_images.astype(np.float32), new_labels.astype(np.float32)

    # Evaluate the attack's success on test data
    def evaluate(self, x_clean, y_clean):
        x_tensor = torch.tensor(x_clean).float()
        poisoned_test = self.apply_trigger(x_tensor).cpu().numpy().astype(np.float32)  # trigger every test image to build a poisoned test set

        preds = self.classifier.predict(poisoned_test)
        true_target = np.full((len(preds),), self.target_label)
        pred_labels = np.argmax(preds, axis=1)

        success = np.sum(pred_labels == true_target)
        asr = 100.0 * success / len(pred_labels)
        return asr
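
For orientation, a minimal usage sketch of the class above. It assumes an existing ART PyTorchClassifier (the name art_classifier is a placeholder) and float32 images in [0, 1] with one-hot labels; the epoch count and rates are illustrative, not values fixed by this commit:

    # Hypothetical setup: art_classifier, x_train, y_train, x_test, y_test are provided by the caller.
    generator = TriggerGenerator(input_channels=3)
    attack = DynamicBackdoorGAN(
        generator=generator,
        target_label=0,          # class the backdoor should force
        backdoor_rate=0.1,       # fraction of the batch that receives the trigger
        classifier=art_classifier,
        epsilon=0.5,             # trigger strength
    )

    x_poison, y_poison = attack.poison(x_train, y_train)  # poisoned copy of the training data
    art_classifier.fit(x_poison, y_poison, nb_epochs=10, batch_size=128)
    print(f"ASR: {attack.evaluate(x_test, y_test):.2f}%")  # attack success rate on triggered test images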

art/estimators/certification/deep_z/pytorch.py

Lines changed: 1 addition & 1 deletion
@@ -169,7 +169,7 @@ def concrete_forward(self, in_x: np.ndarray | "torch.Tensor") -> "torch.Tensor":
            # as reshapes are not modules we infer when the reshape from convolutional to dense occurs
            if self.reshape_op_num == op_num:
                x = x.reshape((x.shape[0], -1))
-           x = op.concrete_forward(x)
+           x = op.concrete_forward(x)  # type: ignore
        return x

    def set_forward_mode(self, mode: str) -> None:

art/estimators/certification/interval/pytorch.py

Lines changed: 1 addition & 1 deletion
@@ -179,7 +179,7 @@ def concrete_forward(self, in_x: np.ndarray | "torch.Tensor") -> "torch.Tensor":
            if isinstance(op, PyTorchIntervalConv2D) and self.forward_mode == "attack":
                x = op.conv_forward(x)
            else:
-               x = op.concrete_forward(x)
+               x = op.concrete_forward(x)  # type: ignore
        return x

    def set_forward_mode(self, mode: str) -> None:

examples/dynamicbackdoorgan_demo.py

Lines changed: 206 additions & 0 deletions
@@ -0,0 +1,206 @@
# -*- coding: utf-8 -*-
"""DynamicBackdoorGAN_Demo.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1aMV5GZ7Z0cwuUl36NxFUsBU5RoJunCGA
"""

# pip install adversarial-robustness-toolbox

# Imports
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import Subset
from torchvision import datasets, transforms, models
from art.estimators.classification import PyTorchClassifier
from art.utils import to_categorical
from art.attacks.poisoning import PoisoningAttackBackdoor

# User config
config = {
    "dataset": "CIFAR10",        # CIFAR10, CIFAR100, MNIST
    "model_name": "resnet18",    # resnet18, resnet50, mobilenetv2, densenet121
    "poison_ratio": 0.1,
    "target_label": 0,           # target label to which poisoned samples are mapped
    "epochs": 30,
    "batch_size": 128,
    "epsilon": 0.5               # trigger strength
}

# Trigger generator: a small CNN that learns to generate input-specific triggers
class TriggerGenerator(nn.Module):
    def __init__(self, input_channels=3):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(input_channels, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, input_channels, kernel_size=3, padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        return self.net(x)

# Custom poisoning attack: DynamicBackdoorGAN. Defines how to poison data using the GAN trigger generator.
class DynamicBackdoorGAN(PoisoningAttackBackdoor):
    def __init__(self, generator, target_label, backdoor_rate, classifier, epsilon=0.5):
        super().__init__(perturbation=lambda x: x)
        self.classifier = classifier
        self.generator = generator.to(classifier.device)
        self.target_label = target_label
        self.backdoor_rate = backdoor_rate
        self.epsilon = epsilon

    # Add the trigger to a batch of images
    def apply_trigger(self, images):
        self.generator.eval()
        with torch.no_grad():
            images = nn.functional.interpolate(images, size=(32, 32), mode='bilinear')  # resize images to a uniform dimension
            triggers = self.generator(images.to(self.classifier.device))  # generate dynamic, input-specific triggers with the trained TriggerGenerator
            poisoned = (images.to(self.classifier.device) + self.epsilon * triggers).clamp(0, 1)  # clamp pixel values to the valid [0, 1] range
        return poisoned

    # Poison the training data by injecting dynamic triggers and changing labels
    def poison(self, x, y):
        # Convert raw image data (x) to float tensors and one-hot labels (y) to class indices
        x_tensor = torch.tensor(x).float()
        y_tensor = torch.tensor(np.argmax(y, axis=1))
        # Total number of samples and how many of them to poison (poison ratio = backdoor_rate)
        batch_size = x_tensor.shape[0]
        n_poison = int(self.backdoor_rate * batch_size)
        # Apply the learned trigger to the first n_poison samples
        poisoned = self.apply_trigger(x_tensor[:n_poison])
        # The remaining samples stay clean
        clean = x_tensor[n_poison:].to(self.classifier.device)
        # Combine poisoned and clean samples into a single batch
        poisoned_images = torch.cat([poisoned, clean], dim=0).cpu().numpy()
        # Relabel the poisoned samples with the attacker's target class
        new_labels = y_tensor.clone()
        new_labels[:n_poison] = self.target_label  # set the poisoned labels to the desired misclassification
        # Convert all labels back to one-hot encoding (required by ART classifiers)
        new_labels = to_categorical(new_labels.numpy(), nb_classes=self.classifier.nb_classes)
        return poisoned_images.astype(np.float32), new_labels.astype(np.float32)

    # Evaluate the attack's success on test data
    def evaluate(self, x_clean, y_clean):
        x_tensor = torch.tensor(x_clean).float()
        poisoned_test = self.apply_trigger(x_tensor).cpu().numpy().astype(np.float32)  # trigger every test image to build a poisoned test set

        preds = self.classifier.predict(poisoned_test)
        true_target = np.full((len(preds),), self.target_label)
        pred_labels = np.argmax(preds, axis=1)

        success = np.sum(pred_labels == true_target)
        asr = 100.0 * success / len(pred_labels)
        return asr

# ✅ Utility: load data
def get_data(dataset="CIFAR10", train_subset=None, test_subset=None):
    if dataset in ["CIFAR10", "CIFAR100"]:
        transform = transforms.Compose([transforms.Resize((32, 32)), transforms.ToTensor()])
    elif dataset == "MNIST":
        transform = transforms.Compose([
            transforms.Grayscale(num_output_channels=3),
            transforms.Resize((32, 32)),
            transforms.ToTensor()
        ])
    else:
        raise ValueError("Unsupported dataset")

    if dataset == "CIFAR10":
        dataset_cls = datasets.CIFAR10
        num_classes = 10
    elif dataset == "CIFAR100":
        dataset_cls = datasets.CIFAR100
        num_classes = 100
    elif dataset == "MNIST":
        dataset_cls = datasets.MNIST
        num_classes = 10

    train_set = dataset_cls(root="./data", train=True, download=True, transform=transform)
    test_set = dataset_cls(root="./data", train=False, download=True, transform=transform)

    if train_subset is not None:
        train_set = Subset(train_set, range(train_subset))
    if test_subset is not None:
        test_set = Subset(test_set, range(test_subset))

    x_train = torch.stack([x for x, _ in train_set]).numpy()
    y_train = to_categorical([y for _, y in train_set], nb_classes=num_classes)

    x_test = torch.stack([x for x, _ in test_set]).numpy()
    y_test = to_categorical([y for _, y in test_set], nb_classes=num_classes)

    return x_train, y_train, x_test, y_test, num_classes

# Utility: get ART classifier. Returns an ART-compatible classifier wrapped around the selected PyTorch model.
def get_classifier(config):
    model_name = config["model_name"]
    nb_classes = config["nb_classes"]
    input_shape = config["input_shape"]
    lr = config.get("learning_rate", 0.001)

    if model_name == "resnet18":
        model = models.resnet18(num_classes=nb_classes)
    elif model_name == "resnet50":
        model = models.resnet50(num_classes=nb_classes)
    elif model_name == "mobilenetv2":
        model = models.mobilenet_v2(num_classes=nb_classes)
    elif model_name == "densenet121":
        model = models.densenet121(num_classes=nb_classes)
    else:
        raise ValueError(f"Unsupported model: {model_name}")

    loss = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    classifier = PyTorchClassifier(
        model=model,
        loss=loss,
        optimizer=optimizer,
        input_shape=input_shape,
        nb_classes=nb_classes,
        clip_values=(0.0, 1.0),
        device_type="gpu" if torch.cuda.is_available() else "cpu"
    )
    return classifier

# Full experiment: runs clean training and poisoned training, then evaluates the effectiveness of the backdoor attack
def run_dynamic_backdoor_experiment(config):
    x_train, y_train, x_test, y_test, num_classes = get_data(
        dataset=config["dataset"],
        train_subset=config.get("train_subset"),
        test_subset=config.get("test_subset")
    )
    config["nb_classes"] = num_classes
    config["input_shape"] = x_train.shape[1:]

    classifier = get_classifier(config)

    # Clean training
    classifier.fit(x_train, y_train, nb_epochs=config["epochs"], batch_size=config["batch_size"])
    clean_acc = np.mean(np.argmax(classifier.predict(x_test), axis=1) == np.argmax(y_test, axis=1))
    print(f"Clean Accuracy: {clean_acc * 100:.2f}%")

    # Poisoned training
    generator = TriggerGenerator()
    attack = DynamicBackdoorGAN(
        generator,
        config["target_label"],
        config["poison_ratio"],
        classifier,
        epsilon=config["epsilon"]
    )
    x_poison, y_poison = attack.poison(x_train, y_train)

    classifier.fit(x_poison, y_poison, nb_epochs=config["epochs"], batch_size=config["batch_size"])
    poisoned_acc = np.mean(np.argmax(classifier.predict(x_test), axis=1) == np.argmax(y_test, axis=1))
    print(f"Poisoned Accuracy: {poisoned_acc * 100:.2f}%")

    asr = attack.evaluate(x_test, y_test)
    print(f"Attack Success Rate (ASR): {asr:.2f}%")

# ✅ Run
run_dynamic_backdoor_experiment(config)
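
The config dict is the single knob for the demo; get_data() also honours optional train_subset and test_subset keys read via config.get(). A quicker smoke-test configuration as a sketch (the subset sizes and epoch count below are illustrative choices, not values from the commit):

    # Hypothetical quick run: MNIST (converted to 3x32x32 by get_data), small subsets, few epochs.
    quick_config = {
        "dataset": "MNIST",
        "model_name": "resnet18",
        "poison_ratio": 0.1,
        "target_label": 0,
        "epochs": 2,
        "batch_size": 64,
        "epsilon": 0.5,
        "train_subset": 2000,   # only the first 2000 training images
        "test_subset": 500,     # only the first 500 test images
    }

    run_dynamic_backdoor_experiment(quick_config)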
