
Commit 2da6a9f

minhitbk authored and Beat Buesser committed
change type of optimizers
Signed-off-by: Ngoc Minh Tran <[email protected]>
1 parent 8ab59d9 commit 2da6a9f

File tree: 1 file changed (+23 −8 lines)


art/attacks/evasion/imperceptible_asr/imperceptible_asr_pytorch.py

Lines changed: 23 additions & 8 deletions
@@ -90,8 +90,8 @@ def __init__(
         max_iter_2: int = 4000,
         learning_rate_1: float = 0.001,
         learning_rate_2: float = 5e-4,
-        optimizer_1: "torch.optim.Optimizer" = "torch.optim.Adam",
-        optimizer_2: "torch.optim.Optimizer" = "torch.optim.Adam",
+        optimizer_1: Optional["torch.optim.Optimizer"] = None,
+        optimizer_2: Optional["torch.optim.Optimizer"] = None,
         global_max_length: int = 200000,
         initial_rescale: float = 1.0,
         decrease_factor_eps: float = 0.8,
@@ -116,8 +116,10 @@ def __init__(
             attack.
         :param learning_rate_1: The learning rate applied for the first stage of the optimization of the attack.
         :param learning_rate_2: The learning rate applied for the second stage of the optimization of the attack.
-        :param optimizer_1: The optimizer applied for the first stage of the optimization of the attack.
-        :param optimizer_2: The optimizer applied for the second stage of the optimization of the attack.
+        :param optimizer_1: The optimizer applied for the first stage of the optimization of the attack. If `None`
+                            attack will use `torch.optim.Adam`.
+        :param optimizer_2: The optimizer applied for the second stage of the optimization of the attack. If `None`
+                            attack will use `torch.optim.Adam`.
         :param global_max_length: The length of the longest audio signal allowed by this attack.
         :param initial_rescale: Initial rescale coefficient to speedup the decrease of the perturbation size during
             the first stage of the optimization of the attack.
@@ -177,10 +179,16 @@ def __init__(
 
         # Create the optimizers
         self._optimizer_arg_1 = optimizer_1
-        self.optimizer_1 = self._optimizer_arg_1(params=[self.global_optimal_delta], lr=self.learning_rate_1)
+        if self._optimizer_arg_1 is None:
+            self.optimizer_1 = torch.optim.Adam(params=[self.global_optimal_delta], lr=self.learning_rate_1)
+        else:
+            self.optimizer_1 = self._optimizer_arg_1(params=[self.global_optimal_delta], lr=self.learning_rate_1)
 
         self._optimizer_arg_2 = optimizer_2
-        self.optimizer_2 = self._optimizer_arg_2(params=[self.global_optimal_delta], lr=self.learning_rate_2)
+        if self._optimizer_arg_2 is None:
+            self.optimizer_2 = torch.optim.Adam(params=[self.global_optimal_delta], lr=self.learning_rate_2)
+        else:
+            self.optimizer_2 = self._optimizer_arg_2(params=[self.global_optimal_delta], lr=self.learning_rate_2)
 
         # Setup for AMP use
         if self._use_amp:
@@ -247,8 +255,15 @@ class only supports targeted attack.
             self.global_optimal_delta.data = torch.zeros(self.batch_size, self.global_max_length).type(torch.float64)
 
             # Next, reset optimizers
-            self.optimizer_1 = self._optimizer_arg_1(params=[self.global_optimal_delta], lr=self.learning_rate_1)
-            self.optimizer_2 = self._optimizer_arg_2(params=[self.global_optimal_delta], lr=self.learning_rate_2)
+            if self._optimizer_arg_1 is None:
+                self.optimizer_1 = torch.optim.Adam(params=[self.global_optimal_delta], lr=self.learning_rate_1)
+            else:
+                self.optimizer_1 = self._optimizer_arg_1(params=[self.global_optimal_delta], lr=self.learning_rate_1)
+
+            if self._optimizer_arg_2 is None:
+                self.optimizer_2 = torch.optim.Adam(params=[self.global_optimal_delta], lr=self.learning_rate_2)
+            else:
+                self.optimizer_2 = self._optimizer_arg_2(params=[self.global_optimal_delta], lr=self.learning_rate_2)
 
             # Then compute the batch
             adv_x_batch = self._generate_batch(adv_x[batch_index_1:batch_index_2], y[batch_index_1:batch_index_2])
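
For reference, a minimal usage sketch of the changed signature (not part of the commit). After this change, leaving optimizer_1/optimizer_2 as None makes the attack construct torch.optim.Adam internally, while passing an optimizer class such as torch.optim.SGD makes the attack call it with params=[self.global_optimal_delta] and the matching learning rate. The import path and the estimator below are assumptions and are not shown in this diff.

import torch

# Import path assumed from the file location art/attacks/evasion/imperceptible_asr/imperceptible_asr_pytorch.py.
from art.attacks.evasion import ImperceptibleASRPyTorch

# Placeholder: a compatible ART speech-recognition estimator is assumed to exist already.
estimator = ...

# Default behaviour after this commit: optimizer_1/optimizer_2 default to None,
# so the attack creates torch.optim.Adam using learning_rate_1/learning_rate_2.
attack_default = ImperceptibleASRPyTorch(estimator, learning_rate_1=0.001, learning_rate_2=5e-4)

# Custom behaviour: pass an optimizer class; the attack instantiates it as
# optimizer(params=[self.global_optimal_delta], lr=<learning rate>), as in the diff above.
attack_custom = ImperceptibleASRPyTorch(estimator, optimizer_1=torch.optim.SGD, optimizer_2=torch.optim.SGD)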
