@@ -317,15 +317,12 @@ class only supports targeted attack.
317317 theta_batch = []
318318 original_max_psd_batch = []
319319
320- for i in range ( len ( x ) ):
321- theta , original_max_psd = self ._compute_masking_threshold (original_input [ i ] )
320+ for _ , x_i in enumerate ( x ):
321+ theta , original_max_psd = self ._compute_masking_threshold (x_i )
322322 theta = theta .transpose (1 , 0 )
323323 theta_batch .append (theta )
324324 original_max_psd_batch .append (original_max_psd )
325325
326- theta_batch = np .array (theta_batch )
327- original_max_psd_batch = np .array (original_max_psd_batch )
328-
329326 # Reset delta with new result
330327 local_batch_shape = successful_adv_input_1st_stage .shape
331328 self .global_optimal_delta .data = torch .zeros (self .batch_size , self .global_max_length ).type (torch .float64 )
@@ -485,7 +482,7 @@ def _forward_1st_stage(
485482 return loss , local_delta , decoded_output , masked_adv_input , local_delta_rescale
486483
487484 def _attack_2nd_stage (
488- self , x : np .ndarray , y : np .ndarray , theta_batch : np .ndarray , original_max_psd_batch : np .ndarray
485+ self , x : np .ndarray , y : np .ndarray , theta_batch : List [ np .ndarray ] , original_max_psd_batch : List [ np .ndarray ]
489486 ) -> "torch.Tensor" :
490487 """
491488 The second stage of the attack.
@@ -544,6 +541,7 @@ class only supports targeted attack.
544541 local_delta_rescale = local_delta_rescale ,
545542 theta_batch = theta_batch ,
546543 original_max_psd_batch = original_max_psd_batch ,
544+ real_lengths = real_lengths ,
547545 )
548546
549547 # Total loss
@@ -597,15 +595,17 @@ class only supports targeted attack.
597595 def _forward_2nd_stage (
598596 self ,
599597 local_delta_rescale : "torch.Tensor" ,
600- theta_batch : np .ndarray ,
601- original_max_psd_batch : np .ndarray ,
598+ theta_batch : List [np .ndarray ],
599+ original_max_psd_batch : List [np .ndarray ],
600+ real_lengths : np .ndarray ,
602601 ) -> "torch.Tensor" :
603602 """
604603 The forward pass of the second stage of the attack.
605604
606605 :param local_delta_rescale: Local delta after rescaled.
607606 :param theta_batch: Original thresholds.
608607 :param original_max_psd_batch: Original maximum psd.
608+ :param real_lengths: Real lengths of original sequences.
609609 :return: The loss tensor of the second stage of the attack.
610610 """
611611 import torch # lgtm [py/repeated-import]
@@ -616,7 +616,7 @@ def _forward_2nd_stage(
616616
617617 for i , _ in enumerate (theta_batch ):
618618 psd_transform_delta = self ._psd_transform (
619- delta = local_delta_rescale [i , :], original_max_psd = original_max_psd_batch [i ]
619+ delta = local_delta_rescale [i , : real_lengths [ i ] ], original_max_psd = original_max_psd_batch [i ]
620620 )
621621
622622 loss = torch .mean (relu (psd_transform_delta - torch .tensor (theta_batch [i ]).to (self .estimator .device )))
0 commit comments