@@ -24,20 +24,16 @@ def _waveforms(
2424 return torch .randn (batch , time , device = DEVICE , dtype = dtype )
2525
2626
27- def test_rand_amp_clip_inplace_preserves_shape ():
27+ def test_rand_amp_clip_preserves_shape ():
2828 waveforms = _waveforms ()
29- ptr = waveforms .data_ptr ()
3029 out = rand_amp_clip (waveforms )
31- assert out .data_ptr () == ptr
3230 assert out .shape == (waveforms .size (0 ), waveforms .size (1 ))
3331 assert torch .isfinite (out ).all ()
3432
3533
36- def test_rand_amp_scale_inplace_preserves_shape ():
34+ def test_rand_amp_scale_preserves_shape ():
3735 waveforms = _waveforms ()
38- ptr = waveforms .data_ptr ()
3936 out = rand_amp_scale (waveforms )
40- assert out .data_ptr () == ptr
4137 assert out .shape == (waveforms .size (0 ), waveforms .size (1 ))
4238 assert torch .isfinite (out ).all ()
4339
@@ -55,11 +51,9 @@ def test_chunk_swap_outputs_permutation():
5551 )
5652
5753
58- def test_freq_drop_no_nan_and_inplace ():
54+ def test_freq_drop_no_nan ():
5955 waveforms = _waveforms ()
60- ptr = waveforms .data_ptr ()
6156 out = freq_drop (waveforms )
62- assert out .data_ptr () == ptr
6357 assert torch .isnan (out ).logical_not ().all ()
6458
6559
@@ -83,23 +77,19 @@ def test_add_noise_with_mock_loader():
8377 from unittest .mock import MagicMock
8478
8579 waveforms = torch .ones (2 , 128 , device = DEVICE , dtype = torch .float32 )
86- ptr = waveforms .data_ptr ()
8780
8881 # Create mock loader that returns zeros
8982 mock_loader = MagicMock ()
9083 mock_loader .get_batch .return_value = torch .zeros (2 , 128 )
9184
9285 out = add_noise (waveforms , mock_loader , snr_low = 0.0 , snr_high = 0.0 )
93- assert out .data_ptr () == ptr
9486 assert torch .isfinite (out ).all ()
9587 mock_loader .get_batch .assert_called_once_with (2 , 128 )
9688
9789
9890def test_add_babble_noise_identity_for_singleton_batch ():
9991 waveforms = torch .full ((1 , 64 ), 2.0 , device = DEVICE , dtype = torch .float32 )
100- ptr = waveforms .data_ptr ()
10192 out = add_babble_noise (waveforms , snr_low = 0.0 , snr_high = 0.0 )
102- assert out .data_ptr () == ptr
10393 assert torch .allclose (out , torch .full_like (out , 2.0 ))
10494
10595
@@ -123,7 +113,6 @@ def test_speed_perturb_adjusts_length():
123113def test_time_dropout_zeroes_segments ():
124114 waveforms = torch .ones (2 , 64 , device = DEVICE , dtype = torch .float32 )
125115 lengths = torch .ones (2 , device = DEVICE , dtype = torch .float32 )
126- ptr = waveforms .data_ptr ()
127116 out = time_dropout (
128117 waveforms ,
129118 lengths = lengths ,
@@ -132,7 +121,6 @@ def test_time_dropout_zeroes_segments():
132121 chunk_size_low = 2 ,
133122 chunk_size_high = 2 ,
134123 )
135- assert out .data_ptr () == ptr
136124 zeros_per_row = (out == 0 ).sum (dim = 1 )
137125 assert torch .all (zeros_per_row >= 2 )
138126
0 commit comments