
Commit 41f23b8

Merge pull request #635 from Trusted-AI/fix_ds_cornercase
Fix deepspeech estimator cornercase
2 parents 4ddb72f + bcd1b3c

3 files changed: +23 −8 lines

art/attacks/evasion/imperceptible_asr/imperceptible_asr_pytorch.py

Lines changed: 2 additions & 1 deletion
@@ -145,7 +145,8 @@ def __init__(
         :param opt_level: Specify a pure or mixed precision optimization level. Used when use_amp is True. Accepted
                            values are `O0`, `O1`, `O2`, and `O3`.
         :param loss_scale: Loss scaling. Used when use_amp is True. Default is 1.0 due to warp-ctc not supporting
-                           scaling of gradients.
+                           scaling of gradients. If passed as a string, must be a string representing a number,
+                           e.g., "1.0", or the string "dynamic".
         """
         import torch  # lgtm [py/repeated-import]
         from torch.autograd import Variable
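For context, when `use_amp` is enabled the `opt_level` and `loss_scale` arguments are ultimately handed to NVIDIA Apex's AMP initialization. A rough sketch of that call follows; the model, optimizer, and exact wiring are placeholders, not copied from ART:

```python
# Sketch only: how opt_level / loss_scale are typically passed to Apex AMP.
# Assumes NVIDIA Apex is installed and a CUDA device is available.
import torch
from apex import amp

model = torch.nn.Linear(161, 29).cuda()  # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# loss_scale accepts a number (e.g. 1.0) or the string "dynamic".
model, optimizer = amp.initialize(model, optimizer, opt_level="O1", loss_scale=1.0)
```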

art/estimators/speech_recognition/pytorch_deep_speech.py

Lines changed: 8 additions & 7 deletions
@@ -94,7 +94,8 @@ def __init__(
         :param opt_level: Specify a pure or mixed precision optimization level. Used when use_amp is True. Accepted
                            values are `O0`, `O1`, `O2`, and `O3`.
         :param loss_scale: Loss scaling. Used when use_amp is True. Default is 1.0 due to warp-ctc not supporting
-                           scaling of gradients.
+                           scaling of gradients. If passed as a string, must be a string representing a number,
+                           e.g., "1.0", or the string "dynamic".
         :param decoder_type: Decoder type. Either `greedy` or `beam`. This parameter is only used when users want
                              transcription outputs.
         :param lm_path: Path to an (optional) kenlm language model for use with beam search. This parameter is only
@@ -285,7 +286,7 @@ def predict(
         """
         import torch  # lgtm [py/repeated-import]
 
-        x_ = x.copy()
+        x_ = np.array([x_i for x_i in x] + [np.array([0.1]), np.array([0.1, 0.2])])[:-2]
 
         # Put the model in the eval mode
         self._model.eval()
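The replaced `x_ = x.copy()` line is the heart of the corner-case fix: when every sample in `x` happens to have the same length, NumPy stores the batch as a 2-D float array rather than the 1-D object array of per-sample waveforms the estimator expects. Appending two dummy samples of different lengths forces the object dtype, and the dummies are sliced off again. A standalone illustration, not taken from the repository, assuming a pre-1.24 NumPy (as used at the time of this commit) where ragged lists are coerced to `dtype=object` automatically:

```python
import numpy as np

# Two samples of identical length: NumPy builds a rectangular 2-D float array.
x_equal = [np.zeros(100, dtype=np.float32), np.zeros(100, dtype=np.float32)]
print(np.array(x_equal).shape, np.array(x_equal).dtype)   # (2, 100) float32

# The pad-and-trim trick from this commit: the two extra dummy samples have
# different lengths, so NumPy falls back to a 1-D object array, and [:-2]
# removes the dummies again.
x_ = np.array([x_i for x_i in x_equal] + [np.array([0.1]), np.array([0.1, 0.2])])[:-2]
print(x_.shape, x_.dtype)                                  # (2,) object
```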
@@ -370,7 +371,7 @@ def loss_gradient(self, x: np.ndarray, y: np.ndarray, **kwargs) -> np.ndarray:
         """
         from warpctc_pytorch import CTCLoss
 
-        x_ = x.copy()
+        x_ = np.array([x_i for x_i in x] + [np.array([0.1]), np.array([0.1, 0.2])])[:-2]
 
         # Put the model in the training mode
         self._model.train()
@@ -432,8 +433,6 @@ def fit(self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: in
         """
         import random
 
-        import torch  # lgtm [py/repeated-import]
-
         from warpctc_pytorch import CTCLoss
 
         # Put the model in the training mode
@@ -466,8 +465,10 @@ def fit(self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: in
                 )
 
                 # Extract random batch
-                i_batch = x_preprocessed[ind[begin:end]].copy()
-                o_batch = y_preprocessed[ind[begin:end]].copy()
+                i_batch = np.array(
+                    [x_i for x_i in x_preprocessed[ind[begin : end]]] + [np.array([0.1]), np.array([0.1, 0.2])]
+                )[:-2]
+                o_batch = y_preprocessed[ind[begin : end]]
 
                 # Transform data into the model input space
                 inputs, targets, input_rates, target_sizes, batch_idx = self.transform_model_input(
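The same pad-and-trim pattern now appears in `predict`, `loss_gradient`, and `fit`. Purely as an illustration, the pattern could be factored out once; this helper is hypothetical and not part of the commit or of ART:

```python
import numpy as np

def _to_object_array(samples):
    """Hypothetical helper: return a 1-D object array of per-sample arrays,
    even when all samples share the same length."""
    # Two dummy samples of different lengths force dtype=object; trim them off.
    return np.array(list(samples) + [np.array([0.1]), np.array([0.1, 0.2])])[:-2]

# e.g. i_batch = _to_object_array(x_preprocessed[ind[begin:end]])
```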

tests/estimators/speech_recognition/test_pytorch_deep_speech.py

Lines changed: 13 additions & 0 deletions
@@ -183,6 +183,19 @@ def _test_all(self, request, setup_class):
         expected_transcriptions = np.array(["", "", ""])
         assert (expected_transcriptions == transcriptions).all()
 
+        # Test transcription outputs, corner case
+        if request.param is True:
+            transcriptions = self.speech_recognizer_amp.predict(
+                np.array([self.x[0]]), batch_size=2, transcription_output=True
+            )
+        else:
+            transcriptions = self.speech_recognizer.predict(
+                np.array([self.x[0]]), batch_size=2, transcription_output=True
+            )
+
+        expected_transcriptions = np.array([""])
+        assert (expected_transcriptions == transcriptions).all()
+
         # Now test loss gradients
         # Create labels
         y = np.array(["SIX", "HI", "GOOD"])
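The new test exercises exactly that corner case: a one-sample batch, where all waveforms trivially have the same length. A minimal reproduction outside the test harness might look like the sketch below; the pretrained-model name and the availability of the deepspeech.pytorch dependencies are assumptions:

```python
import numpy as np
from art.estimators.speech_recognition import PyTorchDeepSpeech

# Assumes the deepspeech.pytorch dependencies are installed.
speech_recognizer = PyTorchDeepSpeech(pretrained_model="librispeech")

# A single waveform: before this fix, np.array([x]) became a 2-D float array
# and broke the estimator's per-sample preprocessing.
x = np.random.uniform(-1.0, 1.0, size=2000).astype(np.float32)
transcriptions = speech_recognizer.predict(
    np.array([x]), batch_size=1, transcription_output=True
)
print(transcriptions)  # near-silent input typically yields an empty transcription
```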
