2222"""
2323from __future__ import annotations
2424
25+ import math
26+ from typing import TYPE_CHECKING , Callable
27+
2528import numpy as np
26- from typing import TYPE_CHECKING
29+
2730if TYPE_CHECKING :
31+ import torch
2832 from art .utils import CLASSIFIER_NEURALNETWORK_TYPE
2933
3034
3135from art .defences .detector .evasion .evasion_detector import EvasionDetector
3236
-class BeyondDetector(EvasionDetector):
+
+class BeyondDetectorPyTorch(EvasionDetector):
     """
     BEYOND detector for adversarial samples detection.
     This detector uses a combination of SSL and target model predictions to detect adversarial examples.
-
+
     | Paper link: https://openreview.net/pdf?id=S4LqI6CcJ3
     """
-
+
     defence_params = ["target_model", "ssl_model", "augmentations", "aug_num", "alpha", "K", "percentile"]
 
-    def __init__(self,
-        target_model: "CLASSIFIER_NEURALNETWORK_TYPE",
-        ssl_model: "CLASSIFIER_NEURALNETWORK_TYPE",
+    def __init__(
+        self,
+        target_classifier: "CLASSIFIER_NEURALNETWORK_TYPE",
+        ssl_classifier: "CLASSIFIER_NEURALNETWORK_TYPE",
         augmentations: Callable | None,
-        aug_num: int = 50,
-        alpha: float = 0.8,
-        K: int = 20,
-        percentile: int = 5) -> None:
+        aug_num: int = 50,
+        alpha: float = 0.8,
+        K: int = 20,
+        percentile: int = 5,
+    ) -> None:
5158 """
5259 Initialize the BEYOND detector.
5360
54- :param target_model : The target model to be protected
55- :param ssl_model : The self-supervised learning model used for feature extraction
56- :param augmentation : data augmentations for generating neighborhoods
61+ :param target_classifier : The target model to be protected
62+ :param ssl_classifier : The self-supervised learning model used for feature extraction
63+ :param augmentations : data augmentations for generating neighborhoods
5764 :param aug_num: Number of augmentations to apply to each sample (default: 50)
5865 :param alpha: Weight factor for combining label and representation similarities (default: 0.8)
5966 :param K: Number of top similarities to consider (default: 20)
6067 :param percentile: using to calculate the threshold
6168 """
+        import torch
+
         super().__init__()
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-        self.target_model = target_model.to(self.device)
-        self.ssl_model = ssl_model.to(self.device)
+        self.target_model = target_classifier.model.to(self.device)
+        self.ssl_model = ssl_classifier.model.to(self.device)
         self.aug_num = aug_num
         self.alpha = alpha
         self.K = K
 
-        self.backbone = ssl_model.backbone
-        self.classifier = ssl_model.classifier
-        self.projector = ssl_model.projector
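+        # Components of the SSL model used by the detector: feature backbone,
+        # classification head, and projection head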
+        self.backbone = self.ssl_model.backbone
+        self.model_classifier = self.ssl_model.classifier
+        self.projector = self.ssl_model.projector
 
         self.img_augmentations = augmentations
 
-        self.percentile = percentile  # determinate the threshold
-        self.threshold = None
+        self.percentile = percentile  # determine the threshold
+        self.threshold: float | None = None
+
+    def _multi_transform(self, img: "torch.Tensor") -> "torch.Tensor":
+        import torch
 
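+        # Stack aug_num augmented views of each sample along dim=1: (batch, aug_num, C, H, W)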
-
-
-    def _multi_transform(self, img: torch.Tensor) -> torch.Tensor:
         return torch.stack([self.img_augmentations(img) for _ in range(self.aug_num)], dim=1)
 
-    def _get_metrics(self, x: np.ndarray, batch_size: int = 128) -> tuple[dict, np.ndarray]:
+    def _get_metrics(self, x: np.ndarray, batch_size: int = 128) -> np.ndarray:
         """
         Calculate similarities that combine label consistency and representation similarity for the given samples
 
         :param x: Input samples
         :param batch_size: Batch size for processing
         :return: Sorted similarity scores for the given samples
         """
+        import torch
+        import torch.nn.functional as F
+
         samples = torch.from_numpy(x).to(self.device)
-
+
         self.target_model.eval()
         self.backbone.eval()
-        self.classifier.eval()
+        self.model_classifier.eval()
         self.projector.eval()
 
         number_batch = int(math.ceil(len(samples) / batch_size))
-
+
         similarities = []
 
         with torch.no_grad():
@@ -113,23 +125,31 @@ def _get_metrics(self, x: np.ndarray, batch_size: int = 128) -> tuple[dict, np.n
                 ssl_backbone_out = self.backbone(batch_samples)
 
                 ssl_repre = self.projector(ssl_backbone_out)
-                ssl_pred = self.classifier(ssl_backbone_out)
+                ssl_pred = self.model_classifier(ssl_backbone_out)
                 ssl_label = torch.max(ssl_pred, -1)[1]
 
                 aug_backbone_out = self.backbone(trans_images.reshape(-1, c, h, w))
                 aug_repre = self.projector(aug_backbone_out)
-                aug_pred = self.classifier(aug_backbone_out)
+                aug_pred = self.model_classifier(aug_backbone_out)
                 aug_pred = aug_pred.reshape(b, self.aug_num, -1)
 
-                sim_repre = F.cosine_similarity(ssl_repre.unsqueeze(dim=1), aug_repre.reshape(b, self.aug_num, -1), dim=2)
-                sim_preds = F.cosine_similarity(F.one_hot(torch.argmax(ssl_label, dim=1), num_classes=ssl_pred.shape[-1]).unsqueeze(dim=1), aug_pred, dim=2)
+                sim_repre = F.cosine_similarity(
+                    ssl_repre.unsqueeze(dim=1), aug_repre.reshape(b, self.aug_num, -1), dim=2
+                )
+
+                sim_preds = F.cosine_similarity(
+                    F.one_hot(ssl_label, num_classes=ssl_pred.shape[-1]).unsqueeze(dim=1),
+                    aug_pred,
+                    dim=2,
+                )
 
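+                # Combine label consistency (sim_preds) and representation similarity (sim_repre),
+                # weighted by alpha, and keep each sample's scores sorted in descending order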
-                similarities.append((self.alpha * sim_preds + (1 - self.alpha) * sim_repre).sort(descending=True)[0].cpu().numpy())
+                similarities.append(
+                    (self.alpha * sim_preds + (1 - self.alpha) * sim_repre).sort(descending=True)[0].cpu().numpy()
+                )
 
         similarities = np.concatenate(similarities, axis=0)
-
-        return similarities
 
+        return similarities
 
     def fit(self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: int = 20, **kwargs) -> None:
         """
@@ -140,26 +160,26 @@ def fit(self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: in
         :param batch_size: Batch size for processing
         :param nb_epochs: Number of training epochs (not used in this method)
         """
-        k_minus_one_metrics = clean_metrics[:, self.K - 1]
-
-        self.threshold = np.percentile(k_minus_one_metrics, self.threshold)
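+        # Calibrate on clean data: the detection threshold is the given percentile of the
+        # K-th largest combined similarity over the clean samples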
+        clean_metrics = self._get_metrics(x=x, batch_size=batch_size)
+        k_minus_one_metrics = clean_metrics[:, self.K - 1]
+        self.threshold = np.percentile(k_minus_one_metrics, q=self.percentile)
 
     def detect(self, x: np.ndarray, batch_size: int = 128, **kwargs) -> tuple[dict, np.ndarray]:
         """
         Detect whether given samples are adversarial
-
+
         :param x: Input samples
         :param batch_size: Batch size for processing
         :return: (report, is_adversarial):
-                where report containing detection results
+                where report contains the detection results
                 where is_adversarial is a boolean list indicating whether samples are adversarial or not
         """
         if self.threshold is None:
             raise ValueError("Detector has not been fitted. Call fit() before detect().")
-
+
         similarities = self._get_metrics(x, batch_size)
-
-        report = similarities[:, self.K - 1]
+
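+        # Detection score: the K-th largest combined similarity; samples scoring
+        # below the fitted threshold are flagged as adversarial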
+        report = similarities[:, self.K - 1]
         is_adversarial = report < self.threshold
-
+
         return report, is_adversarial
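
For reference, a minimal usage sketch of the new detector follows. It is illustrative only: `target_net`, `ssl_net`, and the data arrays are assumed to exist and be trained, the SSL network is assumed to expose `backbone`, `classifier`, and `projector` attributes as required above, and the import path for `BeyondDetectorPyTorch` is assumed to match the package layout shown in this diff.

# Minimal usage sketch (illustrative names; models and data are assumed to be defined elsewhere).
import torch
import torchvision.transforms as T

from art.estimators.classification import PyTorchClassifier
from art.defences.detector.evasion import BeyondDetectorPyTorch

# Wrap the target model and the SSL model (which exposes .backbone, .classifier, .projector)
# in ART classifiers; input shape and class count are example values.
target_classifier = PyTorchClassifier(
    model=target_net, loss=torch.nn.CrossEntropyLoss(), input_shape=(3, 32, 32), nb_classes=10
)
ssl_classifier = PyTorchClassifier(
    model=ssl_net, loss=torch.nn.CrossEntropyLoss(), input_shape=(3, 32, 32), nb_classes=10
)

# Random augmentations used to build each sample's neighborhood.
augmentations = T.Compose([
    T.RandomResizedCrop(32, scale=(0.2, 1.0)),
    T.RandomHorizontalFlip(),
])

detector = BeyondDetectorPyTorch(
    target_classifier=target_classifier,
    ssl_classifier=ssl_classifier,
    augmentations=augmentations,
    aug_num=50,
    alpha=0.8,
    K=20,
    percentile=5,
)

detector.fit(x_train_clean, y_train)              # calibrate the threshold on clean data
report, is_adversarial = detector.detect(x_test)  # flag suspected adversarial samples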