Skip to content

Commit 7ad55cf

Browse files
authored
improvements to wav2aug (#5)
* align w speechbrain
* refactor chunk_swap, freq_drop, time_drop
* allow for same augmentation selection
* single augment test
* fix circ
* add noise workers
* update clip, freq_drop, and noise
* Revert "update clip, freq_drop, and noise" — this reverts commit 362cb47.
* Reapply "update clip, freq_drop, and noise" — this reverts commit 00cceea.
* fix torchaudio version #3 (#4)
* added preload for add_noise
* improve speed_pert and chunk_swap
* clean up
* fixed speed_pert bad gcd issue
* formatting
* fixed tests
* update readme
1 parent e9b7053 commit 7ad55cf

15 files changed

+710
-296
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ dependencies = [
3030
"torch>=2.0.0",
3131
"torchaudio>=2.0.0",
3232
"torchcodec>=0.7.0",
33+
"tqdm>=4.0.0",
3334
]
3435

3536
[project.optional-dependencies]

readme.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ A minimalistic PyTorch-based audio augmentation library for speech and audio aug
66

77
## ⚙️ Features
88

9-
* **Minimal dependencies**: we only rely on PyTorch, torchcodec, and torchaudio.
109
* **9 core augmentations**: amplitude scaling/clipping, noise addition, frequency dropout, polarity inversion, chunk swapping, speed perturbation, time dropout, and babble noise.
1110
* **Simplicity**: just install and start augmenting!
1211
* **Randomness**: all stochastic ops use PyTorch RNGs. Set a single seed and be done, e.g. torch.manual_seed(0); torch.cuda.manual_seed_all(0)

tests/test_gpu_augmentations.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def test_chunk_swap_outputs_permutation():
4949
)
5050
reference = base.clone()
5151
out = chunk_swap(base)
52-
assert out.data_ptr() == base.data_ptr()
52+
assert out.shape == base.shape
5353
assert torch.allclose(
5454
torch.sort(out, dim=1).values, torch.sort(reference, dim=1).values
5555
)
@@ -74,7 +74,7 @@ def _stub_noise_like(ref, sample_rate, noise_dir):
7474
ptr = waveforms.data_ptr()
7575
out = add_noise(
7676
waveforms,
77-
sample_rate=16_000,
77+
16_000, # sample_rate as positional argument
7878
snr_low=0.0,
7979
snr_high=0.0,
8080
download=False,
@@ -103,8 +103,9 @@ def test_speed_perturb_adjusts_length():
103103
waveforms = torch.linspace(
104104
0, 1, steps=200, device=DEVICE, dtype=torch.float32
105105
).repeat(2, 1)
106-
out = speed_perturb(waveforms, 16000, speed_changes=(0.5,))
107-
expected_len = int(round(200 * 1 / 0.5))
106+
out = speed_perturb(waveforms, 16000, speeds=(50,))
107+
# speed=50% → ratio=2.0 → 2x samples (slower)
108+
expected_len = int(200 * 2.0)
108109
assert out.shape == (2, expected_len)
109110

110111

uv.lock

Lines changed: 93 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

wav2aug/data/fetch.py

Lines changed: 13 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,15 @@
55
import logging
66
import os
77
import pathlib
8-
import shutil
9-
import sys
108
import tarfile
119
import tempfile
1210
import time
1311
import urllib.request
1412
from pathlib import Path
1513
from urllib.parse import urlparse
1614

15+
from tqdm import tqdm
16+
1717
try:
1818
import fcntl
1919

@@ -59,51 +59,25 @@ def _safe_extract_tar_gz(tgz_path: str, dest_dir: str) -> None:
5959

6060
def _download(url: str, out_path: str) -> None:
6161
"""Download with simple progress bar to stderr."""
62-
show = os.environ.get("WAV2AUG_PROGRESS", "1") != "0"
63-
64-
def _progress(done: int, total: int):
65-
w = max(10, min(40, shutil.get_terminal_size(fallback=(80, 20)).columns - 30))
66-
if total > 0:
67-
pct = done / total
68-
fill = int(pct * w)
69-
bar = "#" * fill + "." * (w - fill)
70-
sys.stderr.write(
71-
f"\rwav2aug - Progress [{bar}] {pct*100:5.1f}% {done/1e6:6.1f}MB/{total/1e6:6.1f}MB"
72-
)
73-
else:
74-
sys.stderr.write(f"\rwav2aug - Progress {done/1e6:6.1f}MB")
75-
sys.stderr.flush()
76-
7762
name = Path(urlparse(url).path).name or "download"
7863

79-
sys.stderr.write(f"wav2aug - Downloading: {name}\n")
80-
sys.stderr.flush()
81-
8264
req = urllib.request.Request(url, headers={"User-Agent": "wav2aug/1.0"})
8365
start = time.monotonic()
8466
with urllib.request.urlopen(req) as r, open(out_path, "wb") as f:
8567
total = int(r.headers.get("Content-Length") or 0)
8668
chunk = 1 << 20
8769
done = 0
88-
last = time.monotonic()
89-
tty = show and sys.stderr.isatty()
90-
if tty:
91-
_progress(0, total)
92-
93-
while True:
94-
buf = r.read(chunk)
95-
if not buf:
96-
break
97-
f.write(buf)
98-
done += len(buf)
99-
if tty and (time.monotonic() - last) >= 0.05:
100-
_progress(done, total)
101-
last = time.monotonic()
102-
103-
if tty:
104-
_progress(done, total)
105-
sys.stderr.write("\n")
106-
sys.stderr.flush()
70+
71+
with tqdm(
72+
total=total, desc=f"Downloading {name}", unit="B", unit_scale=True
73+
) as pbar:
74+
while True:
75+
buf = r.read(chunk)
76+
if not buf:
77+
break
78+
f.write(buf)
79+
done += len(buf)
80+
pbar.update(len(buf))
10781

10882
elapsed = max(1e-6, time.monotonic() - start)
10983
log.info(

wav2aug/gpu/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from .amplitude_scaling import rand_amp_scale
33
from .chunk_swapping import chunk_swap
44
from .frequency_dropout import freq_drop
5-
from .noise_addition import add_babble_noise, add_noise
5+
from .noise_addition import NoiseLoader, add_babble_noise, add_noise
66
from .polarity_inversion import invert_polarity
77
from .speed_perturbation import speed_perturb
88
from .time_dropout import time_dropout
@@ -15,6 +15,7 @@
1515
"freq_drop",
1616
"add_noise",
1717
"add_babble_noise",
18+
"NoiseLoader",
1819
"invert_polarity",
1920
"speed_perturb",
2021
"time_dropout",

wav2aug/gpu/amplitude_clipping.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@ def rand_amp_clip(
1313
) -> torch.Tensor:
1414
"""Random amplitude clipping for batched waveforms.
1515
16+
Normalizes each waveform to [-1, 1], applies clipping, then restores
17+
the original amplitude scaled by the clip factor.
18+
1619
Args:
1720
waveforms: Tensor of shape [batch, time].
1821
clip_low: Minimum clipping threshold as a fraction of peak.
@@ -30,19 +33,22 @@ def rand_amp_clip(
3033

3134
device = waveforms.device
3235
dtype = waveforms.dtype
33-
peaks = waveforms.abs().amax(dim=1, keepdim=True).clamp_min(1.0)
34-
normalized = waveforms / peaks
3536

36-
# Per-sample clip thresholds
37-
clip = torch.rand((waveforms.size(0), 1), device=device, dtype=dtype)
37+
# Normalize to [-1, 1] by absolute max
38+
abs_max = waveforms.abs().amax(dim=1, keepdim=True)
39+
abs_max = abs_max.clamp_min(eps)
40+
waveforms.div_(abs_max)
41+
42+
# Single clip value for entire batch (matches SpeechBrain)
43+
clip = torch.rand(1, device=device, dtype=dtype)
3844
clip = clip * (clip_high - clip_low) + clip_low
3945
clip = clip.clamp_min(eps)
4046

41-
normalized = torch.minimum(normalized, clip)
42-
normalized = torch.maximum(normalized, -clip)
47+
# Apply clipping
48+
waveforms.clamp_(-clip, clip)
4349

44-
scale = peaks / clip
45-
waveforms.copy_(normalized * scale)
50+
# Restore amplitude scaled by clip factor
51+
waveforms.mul_(abs_max / clip)
4652
return waveforms
4753

4854

wav2aug/gpu/amplitude_scaling.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@ def rand_amp_scale(
1212
) -> torch.Tensor:
1313
"""Random amplitude scaling for batched waveforms.
1414
15+
Normalizes each waveform to [-1, 1] then applies a random amplitude
16+
scale factor.
17+
1518
Args:
1619
waveforms: Tensor of shape [batch, time].
1720
amp_low: Minimum amplitude scale factor.
@@ -28,12 +31,17 @@ def rand_amp_scale(
2831

2932
device = waveforms.device
3033
dtype = waveforms.dtype
31-
denom = waveforms.abs().amax(dim=1, keepdim=True).clamp_min(1.0)
34+
35+
# Normalize to [-1, 1] by dividing by absolute max
36+
abs_max = waveforms.abs().amax(dim=1, keepdim=True)
37+
# Avoid division by zero for silent signals
38+
abs_max = abs_max.clamp_min(1e-14)
39+
waveforms.div_(abs_max)
3240

3341
# Per-sample scaling factors
3442
scales = torch.rand((waveforms.size(0), 1), device=device, dtype=dtype)
3543
scales = scales * (amp_high - amp_low) + amp_low
36-
waveforms.mul_(scales / denom)
44+
waveforms.mul_(scales)
3745
return waveforms
3846

3947

0 commit comments

Comments (0)