@@ -481,15 +481,20 @@ def test_specaugment(self, n_time_masks, time_mask_param, n_freq_masks, freq_mas
481
481
482
482
@parameterized .expand (
483
483
[
484
- ((32000 ,), (0 ,), 16000 ),
485
- ((1 , 32000 ), (1 , 0 ), 32000 ),
486
- ((2 , 44100 ), (2 , 0 ), 32000 ),
487
- ((2 , 2 , 44100 ), (2 , 2 , 0 ), 32000 ),
484
+ ((32000 ,), (0 ,), 16000 , 0.0 ),
485
+ ((1 , 32000 ), (1 , 0 ), 32000 , 0.0 ),
486
+ ((2 , 44100 ), (2 , 0 ), 32000 , 0.0 ),
487
+ ((2 , 2 , 44100 ), (2 , 2 , 0 ), 32000 , 0.0 ),
488
+ ((32000 ,), (16000 ,), 16000 , 1.0 ),
489
+ ((32000 ,), (32000 ,), 16000 , 4.0 ),
490
+ ((1 , 32000 ), (1 , 32000 ), 32000 , 1.0 ),
491
+ ((2 , 44100 ), (2 , 32000 ), 32000 , 1.0 ),
492
+ ((2 , 2 , 44100 ), (2 , 2 , 32000 ), 32000 , 1.0 ),
488
493
]
489
494
)
490
- def test_vad_on_zero_audio (self , input_shape , output_shape , sample_rate : int ):
491
- """VAD should return zero when input is zero Tensor"""
495
+ def test_vad_on_zero_audio (self , input_shape , output_shape , sample_rate : int , pre_trigger_time : float ):
496
+ """VAD should return zero when input is zero Tensor when pre_trigger_time=0 """
492
497
inpt = torch .zeros (input_shape , dtype = self .dtype , device = self .device )
493
498
expected_output = torch .zeros (output_shape , dtype = self .dtype , device = self .device )
494
- result = T .Vad (sample_rate )(inpt )
499
+ result = T .Vad (sample_rate , pre_trigger_time = pre_trigger_time )(inpt )
495
500
self .assertEqual (result , expected_output )
0 commit comments