@@ -22,6 +22,7 @@ def __init__(
2222 sample_rate : int ,
2323 noise_dir : str | None = None ,
2424 noise_preload : bool = True ,
25+ top_k : int = 9 ,
2526 ) -> None :
2627 """Initialize Wav2Aug.
2728
@@ -31,6 +32,11 @@ def __init__(
3132 default cached noise pack (auto-downloaded if needed).
3233 noise_preload: If True (default), preload all noise files into CPU RAM
3334 at initialization for fast sampling. If False, load files on-demand.
35+ top_k: Number of top augmentations to use, ordered by effectiveness.
36+ Default is 9 (all augmentations). Common values: 3, 6, or 9.
37+ Order (best to worst): Noise Addition, Freq Drop, Time Drop,
38+ Speed Perturb, Amp Clip, Chunk Swap, Babble Noise, Amp Scale,
39+ Polarity Inversion.
3440 """
3541 self .sample_rate = int (sample_rate )
3642
@@ -41,24 +47,31 @@ def __init__(
4147 noise_dir = ensure_pack ("pointsource_noises" )
4248 self ._noise_loader = NoiseLoader (noise_dir , sample_rate , preload = noise_preload )
4349
44- self . _base_ops : List [
45- Callable [[torch .Tensor , torch .Tensor | None ], torch .Tensor ]
46- ] = [
50+ # All ops ordered by effectiveness (best first)
51+ all_ops : List [ Callable [[torch .Tensor , torch .Tensor | None ], torch .Tensor ]] = [
52+ # top 3
4753 lambda x , lengths : add_noise (
4854 x , self ._noise_loader , snr_low = 0.0 , snr_high = 10.0
4955 ),
50- lambda x , lengths : add_babble_noise (x ),
51- lambda x , lengths : chunk_swap (x ),
5256 lambda x , lengths : freq_drop (x ),
53- lambda x , lengths : invert_polarity (x ),
54- lambda x , lengths : rand_amp_clip (x ),
55- lambda x , lengths : rand_amp_scale (x ),
56- lambda x , lengths : speed_perturb (x , sample_rate = self .sample_rate ),
5757 lambda x , lengths : time_dropout (
5858 x , sample_rate = self .sample_rate , lengths = lengths
5959 ),
60+ # top 6
61+ lambda x , lengths : speed_perturb (x , sample_rate = self .sample_rate ),
62+ lambda x , lengths : rand_amp_clip (x ),
63+ lambda x , lengths : chunk_swap (x ),
64+ # all 9
65+ lambda x , lengths : add_babble_noise (x ),
66+ lambda x , lengths : rand_amp_scale (x ),
67+ lambda x , lengths : invert_polarity (x ),
6068 ]
6169
70+ if top_k < 1 or top_k > len (all_ops ):
71+ raise ValueError (f"top_k must be between 1 and { len (all_ops )} , got { top_k } " )
72+
73+ self ._base_ops = all_ops [:top_k ]
74+
6275 @torch .no_grad ()
6376 def __call__ (
6477 self ,
0 commit comments