Skip to content

Commit 5ed7e20

Browse files
authored
Merge pull request #1198 from Trusted-AI/fix_asr_batch
ASR attack supports batch size larger than 1
2 parents 08c92be + 58f20a7 commit 5ed7e20

File tree

5 files changed

+227
-178
lines changed

5 files changed

+227
-178
lines changed

.github/actions/deepspeech-v2/run.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#!/bin/sh -l
1+
#!/bin/bash
22

33
exit_code=0
44

.github/actions/deepspeech-v3/run.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#!/bin/sh -l
1+
#!/bin/bash
22

33
exit_code=0
44

art/attacks/evasion/imperceptible_asr/imperceptible_asr_pytorch.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -317,15 +317,12 @@ class only supports targeted attack.
317317
theta_batch = []
318318
original_max_psd_batch = []
319319

320-
for i in range(len(x)):
321-
theta, original_max_psd = self._compute_masking_threshold(original_input[i])
320+
for _, x_i in enumerate(x):
321+
theta, original_max_psd = self._compute_masking_threshold(x_i)
322322
theta = theta.transpose(1, 0)
323323
theta_batch.append(theta)
324324
original_max_psd_batch.append(original_max_psd)
325325

326-
theta_batch = np.array(theta_batch)
327-
original_max_psd_batch = np.array(original_max_psd_batch)
328-
329326
# Reset delta with new result
330327
local_batch_shape = successful_adv_input_1st_stage.shape
331328
self.global_optimal_delta.data = torch.zeros(self.batch_size, self.global_max_length).type(torch.float64)
@@ -485,7 +482,7 @@ def _forward_1st_stage(
485482
return loss, local_delta, decoded_output, masked_adv_input, local_delta_rescale
486483

487484
def _attack_2nd_stage(
488-
self, x: np.ndarray, y: np.ndarray, theta_batch: np.ndarray, original_max_psd_batch: np.ndarray
485+
self, x: np.ndarray, y: np.ndarray, theta_batch: List[np.ndarray], original_max_psd_batch: List[np.ndarray]
489486
) -> "torch.Tensor":
490487
"""
491488
The second stage of the attack.
@@ -544,6 +541,7 @@ class only supports targeted attack.
544541
local_delta_rescale=local_delta_rescale,
545542
theta_batch=theta_batch,
546543
original_max_psd_batch=original_max_psd_batch,
544+
real_lengths=real_lengths,
547545
)
548546

549547
# Total loss
@@ -597,15 +595,17 @@ class only supports targeted attack.
597595
def _forward_2nd_stage(
598596
self,
599597
local_delta_rescale: "torch.Tensor",
600-
theta_batch: np.ndarray,
601-
original_max_psd_batch: np.ndarray,
598+
theta_batch: List[np.ndarray],
599+
original_max_psd_batch: List[np.ndarray],
600+
real_lengths: np.ndarray,
602601
) -> "torch.Tensor":
603602
"""
604603
The forward pass of the second stage of the attack.
605604
606605
:param local_delta_rescale: Local delta after rescaled.
607606
:param theta_batch: Original thresholds.
608607
:param original_max_psd_batch: Original maximum psd.
608+
:param real_lengths: Real lengths of original sequences.
609609
:return: The loss tensor of the second stage of the attack.
610610
"""
611611
import torch # lgtm [py/repeated-import]
@@ -616,7 +616,7 @@ def _forward_2nd_stage(
616616

617617
for i, _ in enumerate(theta_batch):
618618
psd_transform_delta = self._psd_transform(
619-
delta=local_delta_rescale[i, :], original_max_psd=original_max_psd_batch[i]
619+
delta=local_delta_rescale[i, : real_lengths[i]], original_max_psd=original_max_psd_batch[i]
620620
)
621621

622622
loss = torch.mean(relu(psd_transform_delta - torch.tensor(theta_batch[i]).to(self.estimator.device)))

0 commit comments

Comments
 (0)