Skip to content

Commit e514e5d

Browse files
committed
Merge branch 'allenhzy-main' into dev_1.19.0
2 parents 6c57e03 + 5ec99dc commit e514e5d

File tree

6 files changed

+368
-0
lines changed

6 files changed

+368
-0
lines changed

art/defences/detector/evasion/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -6,3 +6,4 @@
66
from art.defences.detector.evasion.binary_input_detector import BinaryInputDetector
77
from art.defences.detector.evasion.binary_activation_detector import BinaryActivationDetector
88
from art.defences.detector.evasion.subsetscanning.detector import SubsetScanningDetector
9+
from art.defences.detector.evasion.beyond_detector import BeyondDetectorPyTorch
art/defences/detector/evasion/beyond_detector.py

Lines changed: 185 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,185 @@
1+
# MIT License
2+
#
3+
# Copyright (C) The Adversarial Robustness Toolbox (ART) Authors 2024
4+
#
5+
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
6+
# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
7+
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit
8+
# persons to whom the Software is furnished to do so, subject to the following conditions:
9+
#
10+
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the
11+
# Software.
12+
#
13+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
14+
# WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
15+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
16+
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
17+
# SOFTWARE.
18+
"""
19+
This module implements the BEYOND detector for adversarial examples detection.
20+
21+
| Paper link: https://openreview.net/pdf?id=S4LqI6CcJ3
22+
"""
23+
from __future__ import annotations
24+
25+
import math
26+
from typing import TYPE_CHECKING, Callable
27+
28+
import numpy as np
29+
30+
if TYPE_CHECKING:
31+
import torch
32+
from art.utils import CLASSIFIER_NEURALNETWORK_TYPE
33+
34+
35+
from art.defences.detector.evasion.evasion_detector import EvasionDetector
36+
37+
38+
class BeyondDetectorPyTorch(EvasionDetector):
    """
    BEYOND detector for adversarial samples detection.

    For every input the detector generates `aug_num` augmented neighbours with the supplied augmentation callable and
    scores each neighbour with a weighted sum of

    - the cosine similarity between the one-hot label predicted by the SSL model for the input and the SSL model's
      class prediction for the neighbour, and
    - the cosine similarity between the projected SSL representations of the input and the neighbour.

    The `K`-th largest score of a sample is compared against a threshold calibrated on clean data in `fit`.

    | Paper link: https://openreview.net/pdf?id=S4LqI6CcJ3
    """

    # NOTE(review): presumably consumed by the base class' get_params/set_params helpers - confirm. Every name listed
    # here is available as an instance attribute after construction.
    defence_params = ["target_model", "ssl_model", "augmentations", "aug_num", "alpha", "K", "percentile"]

    def __init__(
        self,
        target_classifier: "CLASSIFIER_NEURALNETWORK_TYPE",
        ssl_classifier: "CLASSIFIER_NEURALNETWORK_TYPE",
        augmentations: Callable | None,
        aug_num: int = 50,
        alpha: float = 0.8,
        K: int = 20,
        percentile: int = 5,
    ) -> None:
        """
        Initialize the BEYOND detector.

        :param target_classifier: The target model to be protected.
        :param ssl_classifier: The self-supervised learning model used for feature extraction. Its wrapped model must
                               expose `backbone`, `classifier` and `projector` sub-modules.
        :param augmentations: Callable producing one random augmentation of a batch of images; used to generate the
                              neighbourhood of every sample.
        :param aug_num: Number of augmentations to apply to each sample (default: 50).
        :param alpha: Weight factor for combining label and representation similarities (default: 0.8).
        :param K: Rank (1-based) of the similarity used as detection score, must not exceed `aug_num` (default: 20).
        :param percentile: Percentile of the clean scores used as detection threshold (default: 5).
        :raises ValueError: If `aug_num`, `K` or `percentile` are out of range.
        """
        import torch

        super().__init__()

        self.aug_num = aug_num
        self.alpha = alpha
        self.K = K
        self.percentile = percentile  # determines the threshold in `fit`
        self.threshold: float | None = None
        self._validate_hyperparameters()

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.target_model = target_classifier.model.to(self.device)
        self.ssl_model = ssl_classifier.model.to(self.device)

        self.backbone = self.ssl_model.backbone
        self.model_classifier = self.ssl_model.classifier
        self.projector = self.ssl_model.projector

        # Stored under the name advertised in `defence_params`; `img_augmentations` remains available as an alias.
        self.augmentations = augmentations

    @property
    def img_augmentations(self) -> Callable | None:
        """
        Backward-compatible alias of `augmentations`.
        """
        return self.augmentations

    @img_augmentations.setter
    def img_augmentations(self, value: Callable | None) -> None:
        self.augmentations = value

    def _validate_hyperparameters(self) -> None:
        """
        Check that the scalar hyper-parameters are mutually consistent.

        :raises ValueError: If `aug_num` < 1, `K` is outside [1, `aug_num`] or `percentile` is outside [0, 100].
        """
        if self.aug_num < 1:
            raise ValueError("The number of augmentations `aug_num` must be a positive integer.")
        # `K - 1` is used as a column index into `aug_num` sorted similarities.
        if not 1 <= self.K <= self.aug_num:
            raise ValueError("`K` must satisfy 1 <= K <= aug_num.")
        if not 0 <= self.percentile <= 100:
            raise ValueError("`percentile` must be in the range [0, 100].")

    def _multi_transform(self, img: "torch.Tensor") -> "torch.Tensor":
        """
        Generate `aug_num` augmented copies of a batch and stack them along a new dimension 1.

        :param img: Batch of images.
        :return: Stacked augmentations of the batch.
        :raises ValueError: If no augmentation callable is available.
        """
        import torch

        if self.augmentations is None:
            raise ValueError("No augmentations provided: BEYOND requires a callable to generate neighbours.")

        return torch.stack([self.augmentations(img) for _ in range(self.aug_num)], dim=1)

    def _get_metrics(self, x: np.ndarray, batch_size: int = 128) -> np.ndarray:
        """
        Calculate similarities combining label consistency and representation similarity for the given samples.

        :param x: Input samples; each batch is unpacked as (batch, channels, height, width).
        :param batch_size: Batch size for processing.
        :return: Array of shape (nb_samples, aug_num) holding, per sample, the combined similarities to its augmented
                 neighbours sorted in descending order.
        :raises ValueError: If `batch_size` is not positive or the hyper-parameters are inconsistent.
        """
        import torch
        import torch.nn.functional as F

        if batch_size < 1:
            raise ValueError("`batch_size` must be a positive integer.")
        # Attributes may have been modified after construction.
        self._validate_hyperparameters()

        nb_samples = len(x)
        if nb_samples == 0:
            # Nothing to score; np.concatenate would fail on an empty list.
            # NOTE(review): float32 is assumed as the dtype of the model outputs - confirm.
            return np.empty((0, self.aug_num), dtype=np.float32)

        self.target_model.eval()
        self.backbone.eval()
        self.model_classifier.eval()
        self.projector.eval()

        number_batch = int(math.ceil(nb_samples / batch_size))

        similarities = []

        with torch.no_grad():
            for index in range(number_batch):
                start = index * batch_size
                end = min((index + 1) * batch_size, nb_samples)

                # Only the current batch is transferred to the device to keep accelerator memory bounded.
                batch_samples = torch.from_numpy(x[start:end]).to(self.device)
                b, c, h, w = batch_samples.shape

                trans_images = self._multi_transform(batch_samples).to(self.device)
                ssl_backbone_out = self.backbone(batch_samples)

                ssl_repre = self.projector(ssl_backbone_out)
                ssl_pred = self.model_classifier(ssl_backbone_out)
                ssl_label = torch.max(ssl_pred, -1)[1]

                # Fold the augmentation dimension into the batch dimension for a single forward pass.
                aug_backbone_out = self.backbone(trans_images.reshape(-1, c, h, w))
                aug_repre = self.projector(aug_backbone_out)
                aug_pred = self.model_classifier(aug_backbone_out)
                aug_pred = aug_pred.reshape(b, self.aug_num, -1)

                sim_repre = F.cosine_similarity(
                    ssl_repre.unsqueeze(dim=1), aug_repre.reshape(b, self.aug_num, -1), dim=2
                )

                # `one_hot` yields integers; cast explicitly instead of relying on implicit type promotion.
                one_hot_label = F.one_hot(ssl_label, num_classes=ssl_pred.shape[-1]).to(aug_pred.dtype)
                sim_preds = F.cosine_similarity(one_hot_label.unsqueeze(dim=1), aug_pred, dim=2)

                combined = self.alpha * sim_preds + (1 - self.alpha) * sim_repre
                similarities.append(combined.sort(descending=True)[0].cpu().numpy())

        return np.concatenate(similarities, axis=0)

    def fit(self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: int = 20, **kwargs) -> None:
        """
        Determine the detection threshold as the `percentile`-th percentile of the K-th largest similarity of the
        clean samples, i.e. roughly (100 - `percentile`)% of the clean samples score at or above the threshold.

        :param x: Clean sample data.
        :param y: Clean sample labels (not used in this method).
        :param batch_size: Batch size for processing.
        :param nb_epochs: Number of training epochs (not used in this method).
        :raises ValueError: If `x` is empty.
        """
        if len(x) == 0:
            raise ValueError("Cannot fit the detector on an empty set of samples.")

        clean_metrics = self._get_metrics(x=x, batch_size=batch_size)
        k_minus_one_metrics = clean_metrics[:, self.K - 1]
        self.threshold = np.percentile(k_minus_one_metrics, q=self.percentile)

    def detect(self, x: np.ndarray, batch_size: int = 128, **kwargs) -> tuple[dict, np.ndarray]:
        """
        Detect whether the given samples are adversarial.

        :param x: Input samples.
        :param batch_size: Batch size for processing.
        :return: (report, is_adversarial):
                 where report is the K-th largest similarity of every sample
                 where is_adversarial is a boolean array, True where the report is below the fitted threshold
        :raises ValueError: If the detector has not been fitted.
        """
        if self.threshold is None:
            raise ValueError("Detector has not been fitted. Call fit() before detect().")

        similarities = self._get_metrics(x, batch_size)

        # NOTE(review): the return annotation announces a dict, but the report is a NumPy array of scores; the value
        # is left unchanged to stay compatible with existing callers.
        report = similarities[:, self.K - 1]
        is_adversarial = report < self.threshold

        return report, is_adversarial

run_tests.sh

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -146,6 +146,10 @@ else
146146
"tests/defences/test_rounded.py" \
147147
"tests/defences/test_thermometer_encoding.py" \
148148
"tests/defences/test_variance_minimization.py" \
149+
"tests/defences/detector/evasion/test_beyond_detector.py" \
150+
"tests/defences/detector/evasion/test_binary_activation_detector.py" \
151+
"tests/defences/detector/evasion/test_binary_input_detector.py" \
152+
"tests/defences/detector/evasion/test_subsetscanning_detector.py" \
149153
"tests/defences/detector/poison/test_activation_defence.py" \
150154
"tests/defences/detector/poison/test_clustering_analyzer.py" \
151155
"tests/defences/detector/poison/test_ground_truth_evaluator.py" \
tests/defences/detector/evasion/test_beyond_detector.py

Lines changed: 178 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,178 @@
1+
# MIT License
2+
#
3+
# Copyright (C) The Adversarial Robustness Toolbox (ART) Authors 2024
4+
#
5+
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
6+
# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
7+
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit
8+
# persons to whom the Software is furnished to do so, subject to the following conditions:
9+
#
10+
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the
11+
# Software.
12+
#
13+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
14+
# WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
15+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
16+
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
17+
# SOFTWARE.
18+
from __future__ import absolute_import, division, print_function, unicode_literals
19+
20+
import pytest
21+
import numpy as np
22+
23+
from art.attacks.evasion.fast_gradient import FastGradientMethod
24+
from art.defences.detector.evasion import BeyondDetectorPyTorch
25+
from art.estimators.classification import PyTorchClassifier
26+
from tests.utils import ARTTestException
27+
28+
29+
def get_ssl_model(weights_path):
    """
    Build the SSL model (SimSiamWithCls) and load its weights.

    :param weights_path: Path to the state dict of the SimSiam model.
    :return: The SimSiamWithCls module with the weights loaded.
    """
    import torch
    import torch.nn as nn

    class SimSiamWithCls(nn.Module):
        """
        SimSiam with Classifier: a ResNet-18 backbone followed by a linear classification head, a projector and a
        predictor.
        """

        def __init__(self, arch="resnet18", feat_dim=2048, num_proj_layers=2):
            # NOTE(review): `arch` and `num_proj_layers` are accepted but not used; the backbone is always ResNet-18.
            from torchvision import models

            super().__init__()
            self.backbone = models.resnet18()
            # Input width of the original fully connected layer, i.e. the backbone's feature dimension.
            out_dim = self.backbone.fc.weight.shape[1]
            self.backbone.conv1 = nn.Conv2d(
                in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=2, bias=False
            )
            self.backbone.maxpool = nn.Identity()
            # The backbone returns features; classification is done by the separate head below.
            self.backbone.fc = nn.Identity()
            self.classifier = nn.Linear(out_dim, out_features=10)

            pred_hidden_dim = int(feat_dim / 4)

            self.projector = nn.Sequential(
                nn.Linear(out_dim, feat_dim, bias=False),
                nn.BatchNorm1d(feat_dim),
                nn.ReLU(),
                nn.Linear(feat_dim, feat_dim, bias=False),
                nn.BatchNorm1d(feat_dim),
                nn.ReLU(),
                nn.Linear(feat_dim, feat_dim),
                nn.BatchNorm1d(feat_dim, affine=False),
            )
            # The bias of the last linear layer is frozen.
            self.projector[6].bias.requires_grad = False

            self.predictor = nn.Sequential(
                nn.Linear(feat_dim, pred_hidden_dim, bias=False),
                nn.BatchNorm1d(pred_hidden_dim),
                nn.ReLU(),
                nn.Linear(pred_hidden_dim, feat_dim),
            )

        def forward(self, img, im_aug1=None, im_aug2=None):
            """
            Return class logits and representation of `img` when no augmented views are given, otherwise the
            projections and predictions of the two augmented views.
            """
            r_ori = self.backbone(img)
            if im_aug1 is None and im_aug2 is None:
                cls = self.classifier(r_ori)
                rep = self.projector(r_ori)
                return {"cls": cls, "rep": rep}

            r1 = self.backbone(im_aug1)
            r2 = self.backbone(im_aug2)

            z1 = self.projector(r1)
            z2 = self.projector(r2)

            p1 = self.predictor(z1)
            p2 = self.predictor(z2)

            return {"z1": z1, "z2": z2, "p1": p1, "p2": p2}

    model = SimSiamWithCls()
    # Map to CPU so that a checkpoint saved on a GPU can also be loaded on machines without CUDA.
    model.load_state_dict(torch.load(weights_path, map_location=torch.device("cpu")))
    return model
98+
99+
100+
@pytest.mark.only_with_platform("pytorch")
def test_beyond_detector(art_warning, get_default_cifar10_subset):
    """
    Fit the BEYOND detector on clean CIFAR-10 samples and check that it flags at least one FGSM adversarial example
    while accepting at least one clean test sample.
    """
    try:
        import os

        import torch
        from torchvision import models, transforms

        # Load CIFAR10 data
        (x_train, y_train), (x_test, _) = get_default_cifar10_subset

        x_train = x_train[0:100]
        y_train = y_train[0:100]
        x_test = x_test[0:100]

        # Load models
        # Download pretrained weights from
        # https://drive.google.com/drive/folders/1ieEdd7hOj2CIl1FQfu4-3RGZmEj-mesi?usp=sharing
        # The path is anchored to this file so that the test does not depend on the current working directory.
        models_dir = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), "..", "..", "..", "..", "utils", "resources", "models"
        )

        # The target model is left randomly initialised (loading resnet_c10.pth is currently disabled); its output
        # size matches the `nb_classes=10` declared to the classifier wrapper below.
        target_model = models.resnet18(num_classes=10)
        ssl_model = get_ssl_model(weights_path=os.path.join(models_dir, "simsiam_c10.pth"))

        target_classifier = PyTorchClassifier(
            model=target_model, nb_classes=10, input_shape=(3, 32, 32), loss=torch.nn.CrossEntropyLoss()
        )
        ssl_classifier = PyTorchClassifier(
            model=ssl_model, nb_classes=10, input_shape=(3, 32, 32), loss=torch.nn.CrossEntropyLoss()
        )

        # Generate adversarial samples
        attack = FastGradientMethod(estimator=target_classifier, eps=0.05)
        x_test_adv = attack.generate(x_test)

        img_augmentations = transforms.Compose(
            [
                transforms.RandomResizedCrop(32, scale=(0.2, 1.0)),
                transforms.RandomHorizontalFlip(),
                transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),  # not strengthened
                transforms.RandomGrayscale(p=0.2),
                transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
            ]
        )

        # Initialize BeyondDetector
        detector = BeyondDetectorPyTorch(
            target_classifier=target_classifier,
            ssl_classifier=ssl_classifier,
            augmentations=img_augmentations,
            aug_num=50,
            alpha=0.8,
            K=20,
            percentile=5,
        )

        # Fit the detector
        detector.fit(x_train, y_train, batch_size=128)

        # Apply detector on clean and adversarial test data
        _, test_detection = detector.detect(x_test)
        _, test_adv_detection = detector.detect(x_test_adv)

        # Assert there is at least one true positive and negative
        nb_true_positives = np.sum(test_adv_detection)
        nb_true_negatives = len(test_detection) - np.sum(test_detection)

        assert nb_true_positives > 0
        assert nb_true_negatives > 0

        clean_accuracy = 1 - np.mean(test_detection)
        adv_accuracy = np.mean(test_adv_detection)

        assert clean_accuracy > 0.0
        assert adv_accuracy > 0.0

    except ARTTestException as e:
        art_warning(e)
174+
175+
176+
if __name__ == "__main__":
    # The test relies on pytest fixtures (`art_warning`, `get_default_cifar10_subset`) and therefore cannot be called
    # directly without arguments; delegate to pytest so that the fixtures are resolved.
    raise SystemExit(pytest.main([__file__]))

utils/resources/models/resnet_c10.pth

42.7 MB
Binary file not shown.
86.8 MB
Binary file not shown.

0 commit comments

Comments
 (0)