Skip to content

Commit 4852b93

Browse files
authored
removed in-place (#15)
1 parent 1e52551 commit 4852b93

File tree

7 files changed

+20
-41
lines changed

7 files changed

+20
-41
lines changed

tests/test_gpu_augmentations.py

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -24,20 +24,16 @@ def _waveforms(
2424
return torch.randn(batch, time, device=DEVICE, dtype=dtype)
2525

2626

27-
def test_rand_amp_clip_inplace_preserves_shape():
27+
def test_rand_amp_clip_preserves_shape():
2828
waveforms = _waveforms()
29-
ptr = waveforms.data_ptr()
3029
out = rand_amp_clip(waveforms)
31-
assert out.data_ptr() == ptr
3230
assert out.shape == (waveforms.size(0), waveforms.size(1))
3331
assert torch.isfinite(out).all()
3432

3533

36-
def test_rand_amp_scale_inplace_preserves_shape():
34+
def test_rand_amp_scale_preserves_shape():
3735
waveforms = _waveforms()
38-
ptr = waveforms.data_ptr()
3936
out = rand_amp_scale(waveforms)
40-
assert out.data_ptr() == ptr
4137
assert out.shape == (waveforms.size(0), waveforms.size(1))
4238
assert torch.isfinite(out).all()
4339

@@ -55,11 +51,9 @@ def test_chunk_swap_outputs_permutation():
5551
)
5652

5753

58-
def test_freq_drop_no_nan_and_inplace():
54+
def test_freq_drop_no_nan():
5955
waveforms = _waveforms()
60-
ptr = waveforms.data_ptr()
6156
out = freq_drop(waveforms)
62-
assert out.data_ptr() == ptr
6357
assert torch.isnan(out).logical_not().all()
6458

6559

@@ -83,23 +77,19 @@ def test_add_noise_with_mock_loader():
8377
from unittest.mock import MagicMock
8478

8579
waveforms = torch.ones(2, 128, device=DEVICE, dtype=torch.float32)
86-
ptr = waveforms.data_ptr()
8780

8881
# Create mock loader that returns zeros
8982
mock_loader = MagicMock()
9083
mock_loader.get_batch.return_value = torch.zeros(2, 128)
9184

9285
out = add_noise(waveforms, mock_loader, snr_low=0.0, snr_high=0.0)
93-
assert out.data_ptr() == ptr
9486
assert torch.isfinite(out).all()
9587
mock_loader.get_batch.assert_called_once_with(2, 128)
9688

9789

9890
def test_add_babble_noise_identity_for_singleton_batch():
9991
waveforms = torch.full((1, 64), 2.0, device=DEVICE, dtype=torch.float32)
100-
ptr = waveforms.data_ptr()
10192
out = add_babble_noise(waveforms, snr_low=0.0, snr_high=0.0)
102-
assert out.data_ptr() == ptr
10393
assert torch.allclose(out, torch.full_like(out, 2.0))
10494

10595

@@ -123,7 +113,6 @@ def test_speed_perturb_adjusts_length():
123113
def test_time_dropout_zeroes_segments():
124114
waveforms = torch.ones(2, 64, device=DEVICE, dtype=torch.float32)
125115
lengths = torch.ones(2, device=DEVICE, dtype=torch.float32)
126-
ptr = waveforms.data_ptr()
127116
out = time_dropout(
128117
waveforms,
129118
lengths=lengths,
@@ -132,7 +121,6 @@ def test_time_dropout_zeroes_segments():
132121
chunk_size_low=2,
133122
chunk_size_high=2,
134123
)
135-
assert out.data_ptr() == ptr
136124
zeros_per_row = (out == 0).sum(dim=1)
137125
assert torch.all(zeros_per_row >= 2)
138126

wav2aug/gpu/amplitude_clipping.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def rand_amp_clip(
2323
eps: Numerical floor to avoid division by zero.
2424
2525
Returns:
26-
The input ``waveforms`` tensor, modified in-place.
26+
Clipped waveforms.
2727
"""
2828
if waveforms.ndim != 2:
2929
raise AssertionError("expected waveforms shaped [batch, time]")
@@ -37,19 +37,19 @@ def rand_amp_clip(
3737
# Normalize to [-1, 1] by absolute max
3838
abs_max = waveforms.abs().amax(dim=1, keepdim=True)
3939
abs_max = abs_max.clamp_min(eps)
40-
waveforms.div_(abs_max)
40+
out = waveforms / abs_max
4141

4242
# Single clip value for entire batch (matches SpeechBrain)
4343
clip = torch.rand(1, device=device, dtype=dtype)
4444
clip = clip * (clip_high - clip_low) + clip_low
4545
clip = clip.clamp_min(eps)
4646

4747
# Apply clipping
48-
waveforms.clamp_(-clip, clip)
48+
out = out.clamp(-clip, clip)
4949

5050
# Restore amplitude scaled by clip factor
51-
waveforms.mul_(abs_max / clip)
52-
return waveforms
51+
out = out * (abs_max / clip)
52+
return out
5353

5454

5555
__all__ = ["rand_amp_clip"]

wav2aug/gpu/amplitude_scaling.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def rand_amp_scale(
2121
amp_high: Maximum amplitude scale factor.
2222
2323
Returns:
24-
The input ``waveforms`` tensor, modified in-place.
24+
Scaled waveforms.
2525
"""
2626
if waveforms.ndim != 2:
2727
raise AssertionError("expected waveforms shaped [batch, time]")
@@ -36,13 +36,12 @@ def rand_amp_scale(
3636
abs_max = waveforms.abs().amax(dim=1, keepdim=True)
3737
# Avoid division by zero for silent signals
3838
abs_max = abs_max.clamp_min(1e-14)
39-
waveforms.div_(abs_max)
39+
out = waveforms / abs_max
4040

4141
# Per-sample scaling factors
4242
scales = torch.rand((waveforms.size(0), 1), device=device, dtype=dtype)
4343
scales = scales * (amp_high - amp_low) + amp_low
44-
waveforms.mul_(scales)
45-
return waveforms
44+
return out * scales
4645

4746

4847
__all__ = ["rand_amp_scale"]

wav2aug/gpu/frequency_dropout.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -188,10 +188,9 @@ def freq_drop(
188188
dropped = dropped.squeeze(-1)
189189

190190
if clamp_abs is not None and clamp_abs > 0:
191-
dropped = dropped.clamp_(-clamp_abs, clamp_abs)
191+
dropped = dropped.clamp(-clamp_abs, clamp_abs)
192192

193-
waveforms.copy_(dropped.to(dtype))
194-
return waveforms
193+
return dropped.to(dtype)
195194

196195

197196
__all__ = ["freq_drop"]

wav2aug/gpu/noise_addition.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -167,14 +167,9 @@ def _mix_noise(
167167
signal_rms = waveforms.pow(2).mean(dim=1, keepdim=True).sqrt().clamp_min(_EPS)
168168
noise_rms = noise.pow(2).mean(dim=1, keepdim=True).sqrt().clamp_min(_EPS)
169169

170-
# Scale the clean signal by (1 - noise_amplitude_factor)
171-
waveforms.mul_(1.0 - noise_amplitude_factor)
172-
173-
# Compute target noise amplitude and scale noise accordingly
170+
# Mix signal and noise at target SNR
174171
noise_scale = (noise_amplitude_factor * signal_rms) / noise_rms
175-
waveforms.add_(noise * noise_scale)
176-
177-
return waveforms
172+
return waveforms * (1.0 - noise_amplitude_factor) + noise * noise_scale
178173

179174

180175
@torch.no_grad()

wav2aug/gpu/polarity_inversion.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,10 @@ def invert_polarity(
3232

3333
batch = waveforms.size(0)
3434

35+
# Build a per-sample sign multiplier: -1 for flipped, +1 for kept
3536
flips = torch.rand(batch, device=waveforms.device) < prob
36-
if flips.any():
37-
waveforms[flips] *= -1
38-
return waveforms
37+
signs = torch.where(flips, -1.0, 1.0).unsqueeze(1)
38+
return waveforms * signs
3939

4040

4141
__all__ = ["invert_polarity"]

wav2aug/gpu/time_dropout.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def time_dropout(
5656
base_sample_rate: Reference sample rate for scaling chunk lengths.
5757
5858
Returns:
59-
Waveforms with time dropout applied (in-place modification).
59+
Waveforms with time dropout applied.
6060
6161
Raises:
6262
AssertionError: If waveforms are not 2D.
@@ -150,9 +150,7 @@ def time_dropout(
150150
drop_mask = chunk_mask.any(dim=1) # [B, T]
151151

152152
# Zero out masked positions
153-
waveforms.masked_fill_(drop_mask, 0.0)
154-
155-
return waveforms
153+
return waveforms.masked_fill(drop_mask, 0.0)
156154

157155

158156
__all__ = ["time_dropout"]

0 commit comments

Comments (0)