Skip to content

Commit 682ad92

Browse files
authored
add topk #6 (#8)
1 parent e458f9d commit 682ad92

File tree

2 files changed

+52
-9
lines changed

2 files changed

+52
-9
lines changed

tests/test_gpu_augmentations.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,3 +138,33 @@ def _noop_add_noise(waveforms, sample_rate, **kwargs):
138138
out_wave, out_lengths = aug(waveforms, lengths=lengths)
139139
assert out_wave.shape[0] == waveforms.shape[0]
140140
assert out_lengths.data_ptr() == lengths.data_ptr()
141+
142+
143+
def test_wav2aug_top_k_limits_ops(monkeypatch):
    """Verify that the ``top_k`` constructor argument truncates the op list.

    The augmentations are ordered by effectiveness, so ``top_k=k`` must
    leave exactly ``k`` entries in ``_base_ops``; omitting ``top_k`` keeps
    all nine.
    """

    # Stub out noise addition so construction never touches real noise data.
    def _identity_noise(batch, loader, **_ignored):
        return batch

    monkeypatch.setattr("wav2aug.gpu.wav2aug.add_noise", _identity_noise)

    # Expected op counts for each top_k selection:
    #   3 -> noise, freq_drop, time_dropout
    #   6 -> + speed_perturb, amp_clip, chunk_swap
    for k in (3, 6):
        augmenter = Wav2Aug(sample_rate=16_000, top_k=k)
        assert len(augmenter._base_ops) == k

    # Default (no top_k argument) keeps the full set of nine augmentations.
    augmenter = Wav2Aug(sample_rate=16_000)
    assert len(augmenter._base_ops) == 9
def test_wav2aug_top_k_invalid_raises():
    """Verify that out-of-range ``top_k`` values are rejected with ValueError.

    Both below-range (0) and above-range (10) values must raise, and the
    error message must name the valid bounds.
    """
    for bad_k in (0, 10):
        with pytest.raises(ValueError, match="top_k must be between 1 and 9"):
            Wav2Aug(sample_rate=16_000, top_k=bad_k)

wav2aug/gpu/wav2aug.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ def __init__(
2222
sample_rate: int,
2323
noise_dir: str | None = None,
2424
noise_preload: bool = True,
25+
top_k: int = 9,
2526
) -> None:
2627
"""Initialize Wav2Aug.
2728
@@ -31,6 +32,11 @@ def __init__(
3132
default cached noise pack (auto-downloaded if needed).
3233
noise_preload: If True (default), preload all noise files into CPU RAM
3334
at initialization for fast sampling. If False, load files on-demand.
35+
top_k: Number of top augmentations to use, ordered by effectiveness.
36+
Default is 9 (all augmentations). Common values: 3, 6, or 9.
37+
Order (best to worst): Noise Addition, Freq Drop, Time Drop,
38+
Speed Perturb, Amp Clip, Chunk Swap, Babble Noise, Amp Scale,
39+
Polarity Inversion.
3440
"""
3541
self.sample_rate = int(sample_rate)
3642

@@ -41,24 +47,31 @@ def __init__(
4147
noise_dir = ensure_pack("pointsource_noises")
4248
self._noise_loader = NoiseLoader(noise_dir, sample_rate, preload=noise_preload)
4349

44-
self._base_ops: List[
45-
Callable[[torch.Tensor, torch.Tensor | None], torch.Tensor]
46-
] = [
50+
# All ops ordered by effectiveness (best first)
51+
all_ops: List[Callable[[torch.Tensor, torch.Tensor | None], torch.Tensor]] = [
52+
# top 3
4753
lambda x, lengths: add_noise(
4854
x, self._noise_loader, snr_low=0.0, snr_high=10.0
4955
),
50-
lambda x, lengths: add_babble_noise(x),
51-
lambda x, lengths: chunk_swap(x),
5256
lambda x, lengths: freq_drop(x),
53-
lambda x, lengths: invert_polarity(x),
54-
lambda x, lengths: rand_amp_clip(x),
55-
lambda x, lengths: rand_amp_scale(x),
56-
lambda x, lengths: speed_perturb(x, sample_rate=self.sample_rate),
5757
lambda x, lengths: time_dropout(
5858
x, sample_rate=self.sample_rate, lengths=lengths
5959
),
60+
# top 6
61+
lambda x, lengths: speed_perturb(x, sample_rate=self.sample_rate),
62+
lambda x, lengths: rand_amp_clip(x),
63+
lambda x, lengths: chunk_swap(x),
64+
# all 9
65+
lambda x, lengths: add_babble_noise(x),
66+
lambda x, lengths: rand_amp_scale(x),
67+
lambda x, lengths: invert_polarity(x),
6068
]
6169

70+
if top_k < 1 or top_k > len(all_ops):
71+
raise ValueError(f"top_k must be between 1 and {len(all_ops)}, got {top_k}")
72+
73+
self._base_ops = all_ops[:top_k]
74+
6275
@torch.no_grad()
6376
def __call__(
6477
self,

0 commit comments

Comments
 (0)