@@ -332,8 +332,11 @@ def forward(
332332 B_score = 1 - self .cos (grad_ws_norm , d_w2_norm ) # pylint: disable=C0103
333333 return B_score , poisoned_samples
334334
335- x_trigger = torch .tensor (x_trigger , device = device , dtype = torch .float32 )
336- self .grad_ws_norm = _weight_grad (classifier , x_trigger , torch .tensor (y_trigger , device = device )).detach ()
335+ self .grad_ws_norm = _weight_grad (
336+ classifier ,
337+ torch .tensor (x_trigger , device = device , dtype = torch .float32 ),
338+ torch .tensor (y_trigger , device = device ),
339+ ).detach ()
337340 self .grad_ws_norm .requires_grad_ (False )
338341 self .backdoor_model = BackdoorModel (
339342 self ,
@@ -437,7 +440,7 @@ def poison(
437440 x_train [best_indices_poison ] = best_x_poisoned
438441 return x_train , y_train # y_train has not been modified.
439442
440- def __poison__pytorch (self , x_poison : np .ndarray , y_poison : np .ndarray ) -> np .ndarray :
443+ def __poison__pytorch (self , x_poison : np .ndarray , y_poison : np .ndarray ) -> Tuple [ np .ndarray , np . ndarray ] :
441444 """
442445 Optimize the poison by matching the gradient within the perturbation budget.
443446
@@ -507,7 +510,7 @@ def __len__(self):
507510 count += 1
508511 return np .concatenate (all_poisoned_samples , axis = 0 ), B_sum / count
509512
510- def __poison__tensorflow (self , x_poison : np .ndarray , y_poison : np .ndarray ) -> np .ndarray :
513+ def __poison__tensorflow (self , x_poison : np .ndarray , y_poison : np .ndarray ) -> Tuple [ np .ndarray , np . ndarray ] :
511514 """
512515 Optimize the poison by matching the gradient within the perturbation budget.
513516
0 commit comments