Skip to content

Commit b6d4675

Browse files
authored
Fix vad returning zero-length output when a nonzero pre_trigger_time is requested
Differential Revision: D67532573. Pull Request resolved: #3866.
1 parent a6b0a14 commit b6d4675

File tree

2 files changed

+14
-9
lines changed

2 files changed

+14
-9
lines changed

src/torchaudio/functional/filtering.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1662,8 +1662,8 @@ def vad(
16621662
flushedLen_ns = (measures_len - num_measures_to_flush) * measure_period_ns
16631663
break
16641664
# end for window
1665-
if not has_triggered:
1666-
return waveform[..., :0].view(shape[:-1] + torch.Size([0]))
1665+
if not has_triggered and shape[-1] >= fixed_pre_trigger_len_ns:
1666+
return waveform[..., :fixed_pre_trigger_len_ns].view(shape[:-1] + torch.Size([fixed_pre_trigger_len_ns]))
16671667

16681668
res = waveform[:, max(pos - samplesLen_ns + flushedLen_ns, 0) :]
16691669
# unpack batch

test/torchaudio_unittest/transforms/transforms_test_impl.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -481,15 +481,20 @@ def test_specaugment(self, n_time_masks, time_mask_param, n_freq_masks, freq_mas
481481

482482
@parameterized.expand(
483483
[
484-
((32000,), (0,), 16000),
485-
((1, 32000), (1, 0), 32000),
486-
((2, 44100), (2, 0), 32000),
487-
((2, 2, 44100), (2, 2, 0), 32000),
484+
((32000,), (0,), 16000, 0.0),
485+
((1, 32000), (1, 0), 32000, 0.0),
486+
((2, 44100), (2, 0), 32000, 0.0),
487+
((2, 2, 44100), (2, 2, 0), 32000, 0.0),
488+
((32000,), (16000,), 16000, 1.0),
489+
((32000,), (32000,), 16000, 4.0),
490+
((1, 32000), (1, 32000), 32000, 1.0),
491+
((2, 44100), (2, 32000), 32000, 1.0),
492+
((2, 2, 44100), (2, 2, 32000), 32000, 1.0),
488493
]
489494
)
490-
def test_vad_on_zero_audio(self, input_shape, output_shape, sample_rate: int):
491-
"""VAD should return zero when input is zero Tensor"""
495+
def test_vad_on_zero_audio(self, input_shape, output_shape, sample_rate: int, pre_trigger_time: float):
496+
"""VAD should return zero when input is zero Tensor when pre_trigger_time=0"""
492497
inpt = torch.zeros(input_shape, dtype=self.dtype, device=self.device)
493498
expected_output = torch.zeros(output_shape, dtype=self.dtype, device=self.device)
494-
result = T.Vad(sample_rate)(inpt)
499+
result = T.Vad(sample_rate, pre_trigger_time=pre_trigger_time)(inpt)
495500
self.assertEqual(result, expected_output)

0 commit comments

Comments
 (0)