@@ -202,7 +202,9 @@ def calibrate_distance_threshold(
202202
203203 self .distance_threshold_tau = distance_threshold_tau
204204
205- def calibrate_distance_threshold_unsupervised (self , top_t : int , num_samples : int , max_queries : int , ** kwargs ):
205+ def calibrate_distance_threshold_unsupervised (
206+ self , top_t : int = 50 , num_samples : int = 100 , max_queries : int = 1 , ** kwargs
207+ ):
206208 """
207209 Calibrate distance threshold on randomly generated samples, choosing the top-t percentile of the noise needed
208210 to change the classifier's initial prediction.
@@ -211,7 +213,8 @@ def calibrate_distance_threshold_unsupervised(self, top_t: int, num_samples: int
211213
212214 :param top_t: Top-t percentile.
213215 :param num_samples: Number of random samples to generate.
214- :param max_queries: Maximum number of queries.
216+ :param max_queries: Maximum number of queries. Maximum number of HSJ iterations on a single sample will be
217+ max_queries * max_iter.
215218 :Keyword Arguments for HopSkipJump:
216219 * *norm*: Order of the norm. Possible values: "inf", np.inf or 2.
217220 * *max_iter*: Maximum number of iterations.
@@ -222,13 +225,18 @@ def calibrate_distance_threshold_unsupervised(self, top_t: int, num_samples: int
222225 """
223226 from art .attacks .evasion .hop_skip_jump import HopSkipJump
224227
225- x_min , x_max = self .estimator .clip_values
228+ if self .estimator .clip_values is not None :
229+ x_min , x_max = self .estimator .clip_values
230+ else :
231+ raise RuntimeError (
232+ "You need to set the estimator's clip_values in order to calibrate the distance threshold."
233+ )
226234
227235 x_rand = np .random .rand (* (num_samples ,) + self .estimator .input_shape ).astype (np .float32 )
228236 x_rand *= x_max - x_min # scale
229237 x_rand += x_min # shift
230238
231- y_rand = np . random . randint ( 0 , self .estimator .nb_classes , num_samples )
239+ y_rand = self .estimator .predict ( x = x_rand )
232240 y_rand = check_and_transform_label_format (y_rand , self .estimator .nb_classes )
233241
234242 hsj = HopSkipJump (classifier = self .estimator , targeted = False , ** kwargs )
0 commit comments