@@ -43,16 +43,16 @@ class BeyondDetectorPyTorch(EvasionDetector):
4343 | Paper link: https://openreview.net/pdf?id=S4LqI6CcJ3
4444 """
4545
46- defence_params = ["target_model" , "ssl_model" , "augmentations" , "aug_num" , "alpha" , "K " , "percentile" ]
46+ defence_params = ["target_model" , "ssl_model" , "augmentations" , "aug_num" , "alpha" , "var_K " , "percentile" ]
4747
4848 def __init__ (
4949 self ,
5050 target_classifier : "CLASSIFIER_NEURALNETWORK_TYPE" ,
5151 ssl_classifier : "CLASSIFIER_NEURALNETWORK_TYPE" ,
52- augmentations : Callable | None ,
52+ augmentations : Callable ,
5353 aug_num : int = 50 ,
5454 alpha : float = 0.8 ,
55- K : int = 20 ,
55+ var_K : int = 20 ,
5656 percentile : int = 5 ,
5757 ) -> None :
5858 """
@@ -63,7 +63,7 @@ def __init__(
6363 :param augmentations: data augmentations for generating neighborhoods
6464 :param aug_num: Number of augmentations to apply to each sample (default: 50)
6565 :param alpha: Weight factor for combining label and representation similarities (default: 0.8)
66- :param K : Number of top similarities to consider (default: 20)
66+ :param var_K : Number of top similarities to consider (default: 20)
6767 :param percentile: using to calculate the threshold
6868 """
6969 import torch
@@ -75,7 +75,7 @@ def __init__(
7575 self .ssl_model = ssl_classifier .model .to (self .device )
7676 self .aug_num = aug_num
7777 self .alpha = alpha
78- self .K = K
78+ self .var_K = var_K
7979
8080 self .backbone = self .ssl_model .backbone
8181 self .model_classifier = self .ssl_model .classifier
@@ -111,7 +111,7 @@ def _get_metrics(self, x: np.ndarray, batch_size: int = 128) -> np.ndarray:
111111
112112 number_batch = int (math .ceil (len (samples ) / batch_size ))
113113
114- similarities = []
114+ similarities_list = []
115115
116116 with torch .no_grad ():
117117 for index in range (number_batch ):
@@ -143,11 +143,11 @@ def _get_metrics(self, x: np.ndarray, batch_size: int = 128) -> np.ndarray:
143143 dim = 2 ,
144144 )
145145
146- similarities .append (
146+ similarities_list .append (
147147 (self .alpha * sim_preds + (1 - self .alpha ) * sim_repre ).sort (descending = True )[0 ].cpu ().numpy ()
148148 )
149149
150- similarities = np .concatenate (similarities , axis = 0 )
150+ similarities = np .concatenate (similarities_list , axis = 0 )
151151
152152 return similarities
153153
@@ -161,10 +161,10 @@ def fit(self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: in
161161 :param nb_epochs: Number of training epochs (not used in this method)
162162 """
163163 clean_metrics = self ._get_metrics (x = x , batch_size = batch_size )
164- k_minus_one_metrics = clean_metrics [:, self .K - 1 ]
164+ k_minus_one_metrics = clean_metrics [:, self .var_K - 1 ]
165165 self .threshold = np .percentile (k_minus_one_metrics , q = self .percentile )
166166
167- def detect (self , x : np .ndarray , batch_size : int = 128 , ** kwargs ) -> tuple [dict , np .ndarray ]:
167+ def detect (self , x : np .ndarray , batch_size : int = 128 , ** kwargs ) -> tuple [np . ndarray , np .ndarray ]:
168168 """
169169 Detect whether given samples are adversarial
170170
@@ -179,7 +179,7 @@ def detect(self, x: np.ndarray, batch_size: int = 128, **kwargs) -> tuple[dict,
179179
180180 similarities = self ._get_metrics (x , batch_size )
181181
182- report = similarities [:, self .K - 1 ]
182+ report = similarities [:, self .var_K - 1 ]
183183 is_adversarial = report < self .threshold
184184
185185 return report , is_adversarial
0 commit comments