@@ -90,8 +90,8 @@ def __init__(
9090 max_iter_2 : int = 4000 ,
9191 learning_rate_1 : float = 0.001 ,
9292 learning_rate_2 : float = 5e-4 ,
93- optimizer_1 : "torch.optim.Optimizer" = "torch.optim.Adam" ,
94- optimizer_2 : "torch.optim.Optimizer" = "torch.optim.Adam" ,
93+ optimizer_1 : Optional [ "torch.optim.Optimizer" ] = None ,
94+ optimizer_2 : Optional [ "torch.optim.Optimizer" ] = None ,
9595 global_max_length : int = 200000 ,
9696 initial_rescale : float = 1.0 ,
9797 decrease_factor_eps : float = 0.8 ,
@@ -116,8 +116,10 @@ def __init__(
116116 attack.
117117 :param learning_rate_1: The learning rate applied for the first stage of the optimization of the attack.
118118 :param learning_rate_2: The learning rate applied for the second stage of the optimization of the attack.
119- :param optimizer_1: The optimizer applied for the first stage of the optimization of the attack.
120- :param optimizer_2: The optimizer applied for the second stage of the optimization of the attack.
119+ :param optimizer_1: The optimizer applied for the first stage of the optimization of the attack. If `None`
120+ attack will use `torch.optim.Adam`.
121+ :param optimizer_2: The optimizer applied for the second stage of the optimization of the attack. If `None`
122+ attack will use `torch.optim.Adam`.
121123 :param global_max_length: The length of the longest audio signal allowed by this attack.
122124 :param initial_rescale: Initial rescale coefficient to speedup the decrease of the perturbation size during
123125 the first stage of the optimization of the attack.
@@ -177,10 +179,16 @@ def __init__(
177179
178180 # Create the optimizers
179181 self ._optimizer_arg_1 = optimizer_1
180- self .optimizer_1 = self ._optimizer_arg_1 (params = [self .global_optimal_delta ], lr = self .learning_rate_1 )
182+ if self ._optimizer_arg_1 is None :
183+ self .optimizer_1 = torch .optim .Adam (params = [self .global_optimal_delta ], lr = self .learning_rate_1 )
184+ else :
185+ self .optimizer_1 = self ._optimizer_arg_1 (params = [self .global_optimal_delta ], lr = self .learning_rate_1 )
181186
182187 self ._optimizer_arg_2 = optimizer_2
183- self .optimizer_2 = self ._optimizer_arg_2 (params = [self .global_optimal_delta ], lr = self .learning_rate_2 )
188+ if self ._optimizer_arg_2 is None :
189+ self .optimizer_2 = torch .optim .Adam (params = [self .global_optimal_delta ], lr = self .learning_rate_2 )
190+ else :
191+ self .optimizer_2 = self ._optimizer_arg_2 (params = [self .global_optimal_delta ], lr = self .learning_rate_2 )
184192
185193 # Setup for AMP use
186194 if self ._use_amp :
@@ -247,8 +255,15 @@ class only supports targeted attack.
247255 self .global_optimal_delta .data = torch .zeros (self .batch_size , self .global_max_length ).type (torch .float64 )
248256
249257 # Next, reset optimizers
250- self .optimizer_1 = self ._optimizer_arg_1 (params = [self .global_optimal_delta ], lr = self .learning_rate_1 )
251- self .optimizer_2 = self ._optimizer_arg_2 (params = [self .global_optimal_delta ], lr = self .learning_rate_2 )
258+ if self ._optimizer_arg_1 is None :
259+ self .optimizer_1 = torch .optim .Adam (params = [self .global_optimal_delta ], lr = self .learning_rate_1 )
260+ else :
261+ self .optimizer_1 = self ._optimizer_arg_1 (params = [self .global_optimal_delta ], lr = self .learning_rate_1 )
262+
263+ if self ._optimizer_arg_2 is None :
264+ self .optimizer_2 = torch .optim .Adam (params = [self .global_optimal_delta ], lr = self .learning_rate_2 )
265+ else :
266+ self .optimizer_2 = self ._optimizer_arg_2 (params = [self .global_optimal_delta ], lr = self .learning_rate_2 )
252267
253268 # Then compute the batch
254269 adv_x_batch = self ._generate_batch (adv_x [batch_index_1 :batch_index_2 ], y [batch_index_1 :batch_index_2 ])
0 commit comments