Skip to content

Commit 4a40dcb

Browse files
committed
Finalize integration of BEYOND detector
Signed-off-by: Beat Buesser <[email protected]>
1 parent e514e5d commit 4a40dcb

File tree

2 files changed

+13
-13
lines changed

2 files changed

+13
-13
lines changed

art/defences/detector/evasion/beyond_detector.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -43,16 +43,16 @@ class BeyondDetectorPyTorch(EvasionDetector):
4343
| Paper link: https://openreview.net/pdf?id=S4LqI6CcJ3
4444
"""
4545

46-
defence_params = ["target_model", "ssl_model", "augmentations", "aug_num", "alpha", "K", "percentile"]
46+
defence_params = ["target_model", "ssl_model", "augmentations", "aug_num", "alpha", "var_K", "percentile"]
4747

4848
def __init__(
4949
self,
5050
target_classifier: "CLASSIFIER_NEURALNETWORK_TYPE",
5151
ssl_classifier: "CLASSIFIER_NEURALNETWORK_TYPE",
52-
augmentations: Callable | None,
52+
augmentations: Callable,
5353
aug_num: int = 50,
5454
alpha: float = 0.8,
55-
K: int = 20,
55+
var_K: int = 20,
5656
percentile: int = 5,
5757
) -> None:
5858
"""
@@ -63,7 +63,7 @@ def __init__(
6363
:param augmentations: data augmentations for generating neighborhoods
6464
:param aug_num: Number of augmentations to apply to each sample (default: 50)
6565
:param alpha: Weight factor for combining label and representation similarities (default: 0.8)
66-
:param K: Number of top similarities to consider (default: 20)
66+
:param var_K: Number of top similarities to consider (default: 20)
6767
:param percentile: Percentile of the clean-sample similarity scores used to calculate the detection threshold
6868
"""
6969
import torch
@@ -75,7 +75,7 @@ def __init__(
7575
self.ssl_model = ssl_classifier.model.to(self.device)
7676
self.aug_num = aug_num
7777
self.alpha = alpha
78-
self.K = K
78+
self.var_K = var_K
7979

8080
self.backbone = self.ssl_model.backbone
8181
self.model_classifier = self.ssl_model.classifier
@@ -111,7 +111,7 @@ def _get_metrics(self, x: np.ndarray, batch_size: int = 128) -> np.ndarray:
111111

112112
number_batch = int(math.ceil(len(samples) / batch_size))
113113

114-
similarities = []
114+
similarities_list = []
115115

116116
with torch.no_grad():
117117
for index in range(number_batch):
@@ -143,11 +143,11 @@ def _get_metrics(self, x: np.ndarray, batch_size: int = 128) -> np.ndarray:
143143
dim=2,
144144
)
145145

146-
similarities.append(
146+
similarities_list.append(
147147
(self.alpha * sim_preds + (1 - self.alpha) * sim_repre).sort(descending=True)[0].cpu().numpy()
148148
)
149149

150-
similarities = np.concatenate(similarities, axis=0)
150+
similarities = np.concatenate(similarities_list, axis=0)
151151

152152
return similarities
153153

@@ -161,10 +161,10 @@ def fit(self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: in
161161
:param nb_epochs: Number of training epochs (not used in this method)
162162
"""
163163
clean_metrics = self._get_metrics(x=x, batch_size=batch_size)
164-
k_minus_one_metrics = clean_metrics[:, self.K - 1]
164+
k_minus_one_metrics = clean_metrics[:, self.var_K - 1]
165165
self.threshold = np.percentile(k_minus_one_metrics, q=self.percentile)
166166

167-
def detect(self, x: np.ndarray, batch_size: int = 128, **kwargs) -> tuple[dict, np.ndarray]:
167+
def detect(self, x: np.ndarray, batch_size: int = 128, **kwargs) -> tuple[np.ndarray, np.ndarray]:
168168
"""
169169
Detect whether given samples are adversarial
170170
@@ -179,7 +179,7 @@ def detect(self, x: np.ndarray, batch_size: int = 128, **kwargs) -> tuple[dict,
179179

180180
similarities = self._get_metrics(x, batch_size)
181181

182-
report = similarities[:, self.K - 1]
182+
report = similarities[:, self.var_K - 1]
183183
is_adversarial = report < self.threshold
184184

185185
return report, is_adversarial

tests/defences/detector/evasion/test_beyond_detector.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,8 +114,8 @@ def test_beyond_detector(art_warning, get_default_cifar10_subset):
114114
# Download pretrained weights from
115115
# https://drive.google.com/drive/folders/1ieEdd7hOj2CIl1FQfu4-3RGZmEj-mesi?usp=sharing
116116
target_model = models.resnet18()
117-
# target_model.load_state_dict(torch.load("../../../../utils/resources/models/resnet_c10.pth", map_location=torch.device('cpu')))
118-
ssl_model = get_ssl_model(weights_path="../../../../utils/resources/models/simsiam_c10.pth")
117+
# target_model.load_state_dict(torch.load("./utils/resources/models/resnet_c10.pth", map_location=torch.device('cpu')))
118+
ssl_model = get_ssl_model(weights_path="./utils/resources/models/simsiam_c10.pth")
119119

120120
target_classifier = PyTorchClassifier(
121121
model=target_model, nb_classes=10, input_shape=(3, 32, 32), loss=torch.nn.CrossEntropyLoss()

0 commit comments

Comments
 (0)