Skip to content

Commit c305db3

Browse files
authored
Merge pull request #931 from Trusted-AI/fix_asr_nan
Cast type of asr attack to float64 to fix the NaN issue
2 parents 52bdc5f + 43d9626 commit c305db3

File tree

2 files changed

+28
-27
lines changed

2 files changed

+28
-27
lines changed

art/attacks/evasion/imperceptible_asr/imperceptible_asr_pytorch.py

Lines changed: 27 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
import scipy
3232

3333
from art.attacks.attack import EvasionAttack
34-
from art.config import ART_NUMPY_DTYPE
3534
from art.estimators.estimator import BaseEstimator, LossGradientsMixin, NeuralNetworkMixin
3635
from art.estimators.pytorch import PyTorchEstimator
3736
from art.estimators.speech_recognition.pytorch_deep_speech import PyTorchDeepSpeech
@@ -186,20 +185,12 @@ def __init__(
186185
self.optimizer_1st_stage = torch.optim.SGD(
187186
params=[self.global_optimal_delta], lr=self.learning_rate_1st_stage
188187
)
189-
else:
190-
self.optimizer_1st_stage = optimizer_1st_stage(
191-
params=[self.global_optimal_delta], lr=self.learning_rate_1st_stage
192-
)
193188

194189
self._optimizer_2nd_stage_arg = optimizer_2nd_stage
195190
if optimizer_2nd_stage is None:
196191
self.optimizer_2nd_stage = torch.optim.SGD(
197192
params=[self.global_optimal_delta], lr=self.learning_rate_2nd_stage
198193
)
199-
else:
200-
self.optimizer_2nd_stage = optimizer_2nd_stage(
201-
params=[self.global_optimal_delta], lr=self.learning_rate_2nd_stage
202-
)
203194

204195
# Setup for AMP use
205196
if self._use_amp:
@@ -242,7 +233,13 @@ class only supports targeted attack.
242233
)
243234

244235
# Start to compute adversarial examples
245-
adv_x = x.copy()
236+
dtype = x.dtype
237+
238+
# Cast to type float64 to avoid overflow
239+
if dtype.type == np.float64:
240+
adv_x = x.copy()
241+
else:
242+
adv_x = x.copy().astype(np.float64)
246243

247244
# Put the estimator in the training mode, otherwise CUDA can't backpropagate through the model.
248245
# However, estimator uses batch norm layers which need to be frozen
@@ -260,9 +257,9 @@ class only supports targeted attack.
260257
)
261258

262259
# First reset delta
263-
self.global_optimal_delta.data = torch.zeros(self.batch_size, self.global_max_length).type(torch.float32)
260+
self.global_optimal_delta.data = torch.zeros(self.batch_size, self.global_max_length).type(torch.float64)
264261

265-
# Next, reset non-SGD optimizers
262+
# Next, reset optimizers
266263
if self._optimizer_1st_stage_arg is not None:
267264
self.optimizer_1st_stage = self._optimizer_1st_stage_arg(
268265
params=[self.global_optimal_delta], lr=self.learning_rate_1st_stage
@@ -280,6 +277,11 @@ class only supports targeted attack.
280277

281278
# Unfreeze batch norm layers again
282279
self.estimator.set_batchnorm(train=True)
280+
281+
# Recast to the original type if needed
282+
if dtype.type == np.float32:
283+
adv_x = adv_x.astype(dtype)
284+
283285
return adv_x
284286

285287
def _generate_batch(self, x: np.ndarray, y: np.ndarray) -> np.ndarray:
@@ -317,7 +319,7 @@ class only supports targeted attack.
317319

318320
# Reset delta with new result
319321
local_batch_shape = successful_adv_input_1st_stage.shape
320-
self.global_optimal_delta.data = torch.zeros(self.batch_size, self.global_max_length).type(torch.float32)
322+
self.global_optimal_delta.data = torch.zeros(self.batch_size, self.global_max_length).type(torch.float64)
321323
self.global_optimal_delta.data[
322324
: local_batch_shape[0], : local_batch_shape[1]
323325
] = successful_perturbation_1st_stage
@@ -353,11 +355,11 @@ class only supports targeted attack.
353355
local_max_length = np.max(real_lengths)
354356

355357
# Initialize rescale
356-
rescale = np.ones([local_batch_size, local_max_length], dtype=ART_NUMPY_DTYPE) * self.initial_rescale
358+
rescale = np.ones([local_batch_size, local_max_length], dtype=np.float64) * self.initial_rescale
357359

358360
# Reformat input
359-
input_mask = np.zeros([local_batch_size, local_max_length], dtype=ART_NUMPY_DTYPE)
360-
original_input = np.zeros([local_batch_size, local_max_length], dtype=ART_NUMPY_DTYPE)
361+
input_mask = np.zeros([local_batch_size, local_max_length], dtype=np.float64)
362+
original_input = np.zeros([local_batch_size, local_max_length], dtype=np.float64)
361363

362364
for local_batch_size_idx in range(local_batch_size):
363365
input_mask[local_batch_size_idx, : len(x[local_batch_size_idx])] = 1
@@ -521,12 +523,12 @@ class only supports targeted attack.
521523
local_max_length = np.max(real_lengths)
522524

523525
# Initialize alpha and rescale
524-
alpha = np.array([self.initial_alpha] * local_batch_size, dtype=ART_NUMPY_DTYPE)
525-
rescale = np.ones([local_batch_size, local_max_length], dtype=ART_NUMPY_DTYPE) * self.initial_rescale
526+
alpha = np.array([self.initial_alpha] * local_batch_size, dtype=np.float64)
527+
rescale = np.ones([local_batch_size, local_max_length], dtype=np.float64) * self.initial_rescale
526528

527529
# Reformat input
528-
input_mask = np.zeros([local_batch_size, local_max_length], dtype=ART_NUMPY_DTYPE)
529-
original_input = np.zeros([local_batch_size, local_max_length], dtype=ART_NUMPY_DTYPE)
530+
input_mask = np.zeros([local_batch_size, local_max_length], dtype=np.float64)
531+
original_input = np.zeros([local_batch_size, local_max_length], dtype=np.float64)
530532

531533
for local_batch_size_idx in range(local_batch_size):
532534
input_mask[local_batch_size_idx, : len(x[local_batch_size_idx])] = 1
@@ -678,7 +680,7 @@ def _compute_masking_threshold(self, x: np.ndarray) -> Tuple[np.ndarray, np.ndar
678680
barks = 13 * np.arctan(0.00076 * freqs) + 3.5 * np.arctan(pow(freqs / 7500.0, 2))
679681

680682
# Compute quiet threshold
681-
ath = np.zeros(len(barks), dtype=ART_NUMPY_DTYPE) - np.inf
683+
ath = np.zeros(len(barks), dtype=np.float64) - np.inf
682684
bark_idx = np.argmax(barks > 1)
683685
ath[bark_idx:] = (
684686
3.64 * pow(freqs[bark_idx:] * 0.001, -0.8)
@@ -700,7 +702,7 @@ def _compute_masking_threshold(self, x: np.ndarray) -> Tuple[np.ndarray, np.ndar
700702
if len(psd[:, i]) - 1 in masker_idx:
701703
masker_idx = np.delete(masker_idx, len(psd[:, i]) - 1)
702704

703-
barks_psd = np.zeros([len(masker_idx), 3], dtype=ART_NUMPY_DTYPE)
705+
barks_psd = np.zeros([len(masker_idx), 3], dtype=np.float64)
704706
barks_psd[:, 0] = barks[masker_idx]
705707
barks_psd[:, 1] = 10 * np.log10(
706708
pow(10, psd[:, i][masker_idx - 1] / 10.0)
@@ -742,7 +744,7 @@ def _compute_masking_threshold(self, x: np.ndarray) -> Tuple[np.ndarray, np.ndar
742744
for m in range(barks_psd.shape[0]):
743745
d_z = barks - barks_psd[m, 0]
744746
zero_idx = np.argmax(d_z > 0)
745-
s_f = np.zeros(len(d_z), dtype=ART_NUMPY_DTYPE)
747+
s_f = np.zeros(len(d_z), dtype=np.float64)
746748
s_f[:zero_idx] = 27 * d_z[:zero_idx]
747749
s_f[zero_idx:] = (-27 + 0.37 * max(barks_psd[m, 1] - 40, 0)) * d_z[zero_idx:]
748750
t_s.append(barks_psd[m, 1] + delta[m] + s_f)
@@ -764,7 +766,6 @@ def _psd_transform(self, delta: "torch.Tensor", original_max_psd: "torch.Tensor"
764766
:return: The psd matrix.
765767
"""
766768
import torch # lgtm [py/repeated-import]
767-
import torchaudio
768769

769770
# These parameters are needed for the transformation
770771
sample_rate = self.estimator.model.audio_conf.sample_rate
@@ -788,7 +789,7 @@ def _psd_transform(self, delta: "torch.Tensor", original_max_psd: "torch.Tensor"
788789
else:
789790
raise NotImplementedError("Spectrogram window %s not supported." % window)
790791

791-
# return STFT of delta
792+
# Return STFT of delta
792793
delta_stft = torch.stft(
793794
delta,
794795
n_fft=n_fft,
@@ -798,7 +799,7 @@ def _psd_transform(self, delta: "torch.Tensor", original_max_psd: "torch.Tensor"
798799
window=window_fn(win_length).to(self.estimator.device),
799800
).to(self.estimator.device)
800801

801-
# take abs of complex STFT results
802+
# Take abs of complex STFT results
802803
transformed_delta = torch.sqrt(torch.sum(torch.square(delta_stft), -1))
803804

804805
# Compute the psd matrix

tests/attacks/evasion/test_imperceptible_asr_pytorch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def test_imperceptible_asr_pytorch(art_warning, expected_values, use_amp, device
3838

3939
from art.estimators.speech_recognition.pytorch_deep_speech import PyTorchDeepSpeech
4040
from art.attacks.evasion.imperceptible_asr.imperceptible_asr_pytorch import ImperceptibleASRPyTorch
41-
from art.defences.preprocessor import LFilterPyTorch
41+
from art.preprocessing.audio import LFilterPyTorch
4242

4343
try:
4444
# Load data for testing

0 commit comments

Comments
 (0)