@@ -61,6 +61,7 @@ def __init__(
6161 max_iter : int ,
6262 num_grid : int ,
6363 batch_size : int ,
64+ threshold : float ,
6465 ) -> None :
6566 """
6667 Create an overload attack instance.
@@ -70,12 +71,14 @@ def __init__(
7071 :param max_iter: The maximum number of iterations.
7172 :param num_grid: The number of grids for width and high dimension.
7273 :param batch_size: Size of the batch on which adversarial samples are generated.
74+ :param threshold: IoU threshold.
7375 """
7476 super ().__init__ (estimator = estimator )
7577 self .eps = eps
7678 self .max_iter = max_iter
7779 self .num_grid = num_grid
7880 self .batch_size = batch_size
81+ self .threshold = threshold
7982 self ._check_params ()
8083
8184 def generate (self , x : np .ndarray , y : np .ndarray | None = None , ** kwargs ) -> np .ndarray :
@@ -157,10 +160,9 @@ def _loss(self, x: "torch.Tensor") -> tuple["torch.Tensor", "torch.Tensor"]:
157160 if isinstance (adv_logits , tuple ):
158161 adv_logits = adv_logits [0 ]
159162
160- threshold = self .estimator .model .conf
161163 conf = adv_logits [..., 4 ]
162164 prob = adv_logits [..., 5 :]
163- prob = torch .where (conf [:, :, None ] * prob > threshold , torch .ones_like (prob ), prob )
165+ prob = torch .where (conf [:, :, None ] * prob > self . threshold , torch .ones_like (prob ), prob )
164166 prob = torch .sum (prob , dim = 2 )
165167 conf = conf * prob
166168
@@ -185,7 +187,7 @@ def _loss(self, x: "torch.Tensor") -> tuple["torch.Tensor", "torch.Tensor"]:
185187 for x_i in range (x .shape [0 ]):
186188 xyhw = adv_logits [x_i , :, :4 ]
187189 prob = torch .max (adv_logits [x_i , :, 5 :], dim = 1 ).values
188- box_idx = adv_logits [x_i , :, 4 ] * prob > threshold
190+ box_idx = adv_logits [x_i , :, 4 ] * prob > self . threshold
189191 xyhw = xyhw [box_idx ]
190192 c_xyxy = self .xywh2xyxy (xyhw )
191193 scores = box_iou (grid_box , c_xyxy )
@@ -244,3 +246,6 @@ def _check_params(self) -> None:
244246
245247 if self .batch_size < 1 :
246248 raise ValueError ("The batch size must be a positive integer." )
249+
250+ if self .threshold < 0.0 or self .threshold > 1.0 :
251+ raise ValueError ("The threshold must be in the range [0, 1]." )
0 commit comments