from __future__ import absolute_import, division, print_function, unicode_literals

import logging
-from typing import Tuple, TYPE_CHECKING
+from typing import Optional, Tuple, TYPE_CHECKING

import numpy as np
import scipy
@@ -38,7 +38,6 @@

if TYPE_CHECKING:
    import torch
-    from torch.optim import Optimizer

logger = logging.getLogger(__name__)

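As a side note on the import changes above, here is a minimal, hypothetical sketch of the `TYPE_CHECKING` pattern this diff switches to; none of the names below are from the source file, it only illustrates why the annotation becomes a string once `torch` is imported solely for type checking.

from typing import Optional, TYPE_CHECKING

if TYPE_CHECKING:
    # Only evaluated by static type checkers; torch is not imported at runtime.
    import torch


def set_optimizer(optimizer: Optional["torch.optim.Optimizer"] = None) -> None:
    # The string forward reference keeps torch out of the runtime import graph,
    # so this module can be imported without PyTorch being loaded first.
    ...
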
@@ -52,8 +51,6 @@ class ImperceptibleASRPytorch(EvasionAttack):
    | Paper link: https://arxiv.org/abs/1903.10346
    """

-    import torch  # lgtm [py/repeated-import]
-
    attack_params = EvasionAttack.attack_params + [
        "initial_eps",
        "max_iter_1st_stage",
@@ -94,8 +91,8 @@ def __init__(
        max_iter_2nd_stage: int = 4000,
        learning_rate_1st_stage: float = 0.1,
        learning_rate_2nd_stage: float = 0.001,
-        optimizer_1st_stage: "Optimizer" = torch.optim.SGD,
-        optimizer_2nd_stage: "Optimizer" = torch.optim.SGD,
+        optimizer_1st_stage: Optional["torch.optim.Optimizer"] = None,
+        optimizer_2nd_stage: Optional["torch.optim.Optimizer"] = None,
        global_max_length: int = 10000,
        initial_rescale: float = 1.0,
        rescale_factor: float = 0.8,
@@ -123,8 +120,10 @@ def __init__(
            the attack.
        :param learning_rate_2nd_stage: The initial learning rate applied for the second stage of the optimization of
            the attack.
-        :param optimizer_1st_stage: The optimizer applied for the first stage of the optimization of the attack.
-        :param optimizer_2nd_stage: The optimizer applied for the second stage of the optimization of the attack.
+        :param optimizer_1st_stage: The optimizer applied for the first stage of the optimization of the attack. If
+            `None`, the attack will use `torch.optim.SGD`.
+        :param optimizer_2nd_stage: The optimizer applied for the second stage of the optimization of the attack. If
+            `None`, the attack will use `torch.optim.SGD`.
        :param global_max_length: The length of the longest audio signal allowed by this attack.
        :param initial_rescale: Initial rescale coefficient to speedup the decrease of the perturbation size during
            the first stage of the optimization of the attack.
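A hedged usage sketch of the new `None` defaults documented above. The estimator construction, the `estimator=` keyword, and the `art.attacks.evasion` import path are assumptions based on the surrounding package rather than part of this diff; the point being illustrated is that the two arguments take an optimizer class, which the attack instantiates itself.

import torch

from art.attacks.evasion import ImperceptibleASRPytorch


def build_attacks(estimator):
    # `estimator` is assumed to be an ART PyTorch speech-recognition estimator
    # (e.g. PyTorchDeepSpeech) created elsewhere; it is not part of this diff.

    # Omitting both optimizer arguments now falls back to torch.optim.SGD.
    attack_sgd = ImperceptibleASRPytorch(estimator=estimator)

    # A custom optimizer is passed as a class, not an instance; the attack
    # constructs it with its own perturbation parameter and learning rates.
    attack_adam = ImperceptibleASRPytorch(
        estimator=estimator,
        optimizer_1st_stage=torch.optim.Adam,
        optimizer_2nd_stage=torch.optim.Adam,
    )
    return attack_sgd, attack_adam
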
@@ -189,12 +188,22 @@ def __init__(
        self.global_optimal_delta.to(self.estimator.device)

        # Create the optimizers
-        self.optimizer_1st_stage = optimizer_1st_stage(
-            params=[self.global_optimal_delta], lr=self.learning_rate_1st_stage
-        )
-        self.optimizer_2nd_stage = optimizer_2nd_stage(
-            params=[self.global_optimal_delta], lr=self.learning_rate_1st_stage
-        )
+        if optimizer_1st_stage is None:
+            self.optimizer_1st_stage = torch.optim.SGD(
+                params=[self.global_optimal_delta], lr=self.learning_rate_1st_stage
+            )
+        else:
+            self.optimizer_1st_stage = optimizer_1st_stage(
+                params=[self.global_optimal_delta], lr=self.learning_rate_1st_stage
+            )
+        if optimizer_2nd_stage is None:
+            self.optimizer_2nd_stage = torch.optim.SGD(
+                params=[self.global_optimal_delta], lr=self.learning_rate_2nd_stage
+            )
+        else:
+            self.optimizer_2nd_stage = optimizer_2nd_stage(
+                params=[self.global_optimal_delta], lr=self.learning_rate_2nd_stage
+            )

        # Setup for AMP use
        if self._use_amp:
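The rewritten constructor body above also shows why the default moved out of the signature: the old `= torch.optim.SGD` default had to evaluate `torch` when the class was defined (hence the removed class-level import), while `= None` defers that choice to construction time. A minimal, generic sketch of the same idiom, with illustrative names not taken from the source:

def make_optimizer(params, lr, optimizer_class=None):
    # Resolve the default lazily so the heavy dependency is only needed
    # when an optimizer is actually created.
    import torch

    if optimizer_class is None:
        optimizer_class = torch.optim.SGD
    return optimizer_class(params, lr=lr)
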