@@ -481,15 +481,20 @@ def test_specaugment(self, n_time_masks, time_mask_param, n_freq_masks, freq_mas
481481
482482 @parameterized .expand (
483483 [
484- ((32000 ,), (0 ,), 16000 ),
485- ((1 , 32000 ), (1 , 0 ), 32000 ),
486- ((2 , 44100 ), (2 , 0 ), 32000 ),
487- ((2 , 2 , 44100 ), (2 , 2 , 0 ), 32000 ),
484+ ((32000 ,), (0 ,), 16000 , 0.0 ),
485+ ((1 , 32000 ), (1 , 0 ), 32000 , 0.0 ),
486+ ((2 , 44100 ), (2 , 0 ), 32000 , 0.0 ),
487+ ((2 , 2 , 44100 ), (2 , 2 , 0 ), 32000 , 0.0 ),
488+ ((32000 ,), (16000 ,), 16000 , 1.0 ),
489+ ((32000 ,), (32000 ,), 16000 , 4.0 ),
490+ ((1 , 32000 ), (1 , 32000 ), 32000 , 1.0 ),
491+ ((2 , 44100 ), (2 , 32000 ), 32000 , 1.0 ),
492+ ((2 , 2 , 44100 ), (2 , 2 , 32000 ), 32000 , 1.0 ),
488493 ]
489494 )
490- def test_vad_on_zero_audio (self , input_shape , output_shape , sample_rate : int ):
491- """VAD should return zero when input is zero Tensor"""
495+ def test_vad_on_zero_audio (self , input_shape , output_shape , sample_rate : int , pre_trigger_time : float ):
496+ """VAD should return zero when input is zero Tensor when pre_trigger_time=0 """
492497 inpt = torch .zeros (input_shape , dtype = self .dtype , device = self .device )
493498 expected_output = torch .zeros (output_shape , dtype = self .dtype , device = self .device )
494- result = T .Vad (sample_rate )(inpt )
499+ result = T .Vad (sample_rate , pre_trigger_time = pre_trigger_time )(inpt )
495500 self .assertEqual (result , expected_output )
0 commit comments