66import torch .nn .functional as F
77from tqdm import tqdm
88
9- from wav2aug .utils ._aug_utils import _sample_noise_like
10-
119_EPS = 1e-14
1210_AUDIO_EXTS = {".wav" , ".flac" , ".mp3" , ".ogg" , ".opus" , ".m4a" }
1311
@@ -23,21 +21,17 @@ def _list_audio_files(root: str) -> list[str]:
2321
2422
2523class NoiseLoader :
26- """Noise loader with preload-to-memory or on-demand loading .
24+ """Noise loader that preloads all noise files into CPU RAM .
2725
28- By default, loads all noise files into CPU RAM at initialization for
29- zero-I/O sampling during training. For memory-constrained environments,
30- set preload=False to load files on-demand.
26+ Loads all noise files into CPU RAM at initialization for
27+ zero-I/O sampling during training.
3128
3229 Usage:
33- # Preload mode (default, recommended) :
30+ # Default :
3431 noise_loader = NoiseLoader(noise_dir, sample_rate=16000)
3532
36- # On-demand mode (for memory-constrained systems):
37- noise_loader = NoiseLoader(noise_dir, sample_rate=16000, preload=False)
38-
39- # Custom storage dtype (e.g., for even lower memory):
40- noise_loader = NoiseLoader(noise_dir, sample_rate=16000, storage_dtype=torch.float8_e4m3fn)
33+ # Custom storage dtype (e.g., for lower memory):
34+ noise_loader = NoiseLoader(noise_dir, sample_rate=16000, storage_dtype=torch.float16)
4135
4236 # In training loop:
4337 noisy = add_noise(waveforms, noise_loader, snr_low=0, snr_high=10)
@@ -47,39 +41,27 @@ def __init__(
4741 self ,
4842 noise_dir : str ,
4943 sample_rate : int ,
50- preload : bool = True ,
51- storage_dtype : torch .dtype = torch .float16 ,
44+ storage_dtype : torch .dtype = torch .float32 ,
5245 ):
5346 """Initialize the noise loader.
5447
5548 Args:
5649 noise_dir: Directory containing noise audio files.
5750 sample_rate: Target sample rate for noise.
58- preload: If True (default), load all noise files into CPU RAM at
59- initialization. Sampling becomes a fast tensor slice operation
60- with no I/O. If False, load files on-demand (slower but uses
61- less memory).
6251 storage_dtype: Data type for storing preloaded audio in memory.
63- Defaults to float16 (~650MB for pointsource_noises). Use float32
64- for maximum precision, or float8 variants for minimum memory. Note: In
65- my experiments, float16 halved memory usage in exchange for an
66- extremely tiny performance degradation.
52+ Defaults to float32. Use float16 for lower memory usage.
6753 """
6854 self .noise_dir = noise_dir
6955 self .sample_rate = sample_rate
70- self .preload = preload
7156 self .storage_dtype = storage_dtype
7257 self .files = _list_audio_files (noise_dir )
7358 if not self .files :
7459 raise ValueError (f"No audio files found in { noise_dir } " )
7560
7661 # Preloaded noise bank (1D tensor of all concatenated noise)
77- self ._noise_bank : torch .Tensor | None = None
78-
79- if preload :
80- self ._preload_all ()
62+ self ._noise_bank : torch .Tensor = self ._preload_all ()
8163
82- def _preload_all (self ) -> None :
64+ def _preload_all (self ) -> torch . Tensor :
8365 """Load all noise files into memory."""
8466 from torchcodec .decoders import AudioDecoder
8567
@@ -101,17 +83,7 @@ def _preload_all(self) -> None:
10183 f"No valid audio files could be loaded from { self .noise_dir } "
10284 )
10385
104- self ._noise_bank = torch .cat (chunks , dim = 0 )
105-
106- def _load_one (self ) -> torch .Tensor :
107- """Load a single noise sample directly (no preloading)."""
108- from torchcodec .decoders import AudioDecoder
109-
110- idx = torch .randint (0 , len (self .files ), (1 ,)).item ()
111- dec = AudioDecoder (self .files [idx ], sample_rate = self .sample_rate )
112- samp = dec .get_all_samples ()
113- audio = samp .data .contiguous ().mean (dim = 0 ) # mono, shape [time]
114- return audio
86+ return torch .cat (chunks , dim = 0 )
11587
11688 def get_batch (self , batch_size : int , length : int ) -> torch .Tensor :
11789 """Get a batch of noise samples.
@@ -123,54 +95,23 @@ def get_batch(self, batch_size: int, length: int) -> torch.Tensor:
12395 Returns:
12496 Tensor of shape [batch_size, length] on CPU.
12597 """
126- if self ._noise_bank is not None :
127- # Fast path: slice from preloaded noise bank
128- bank_len = self ._noise_bank .shape [0 ]
129-
130- if bank_len <= length :
131- # Noise bank shorter than requested - pad it
132- noise = self ._noise_bank .unsqueeze (0 ).expand (batch_size , - 1 )
133- noise = F .pad (noise , (0 , length - bank_len ))
134- return noise
135-
136- # Generate random start indices for each sample
137- max_start = bank_len - length
138- starts = torch .randint (0 , max_start + 1 , (batch_size ,))
139-
140- # Vectorized slicing: create index tensor [batch_size, length]
141- # where each row is [start, start+1, ..., start+length-1]
142- offsets = torch .arange (length )
143- indices = starts .unsqueeze (1 ) + offsets .unsqueeze (0 ) # [batch_size, length]
144- return self ._noise_bank [indices ]
145- else :
146- # On-demand loading
147- noises = []
148- for _ in range (batch_size ):
149- noise = self ._load_one ()
150- noise = self ._pad_or_crop (noise , length )
151- noises .append (noise )
152- return torch .stack (noises , dim = 0 )
153-
154- def _pad_or_crop (self , noise : torch .Tensor , length : int ) -> torch .Tensor :
155- """Pad or crop noise to target length."""
156- if noise .shape [0 ] < length :
157- noise = F .pad (noise , (0 , length - noise .shape [0 ]))
158- elif noise .shape [0 ] > length :
159- start = torch .randint (0 , noise .shape [0 ] - length + 1 , (1 ,)).item ()
160- noise = noise [start : start + length ]
161- return noise
162-
163- @property
164- def mode (self ) -> str :
165- """Return current loading mode: 'preload' or 'on-demand'."""
166- return "preload" if self ._noise_bank is not None else "on-demand"
167-
168- @property
169- def preloaded_duration_seconds (self ) -> float | None :
170- """Total duration of preloaded audio in seconds, or None if not preloaded."""
171- if self ._noise_bank is not None :
172- return self ._noise_bank .shape [0 ] / self .sample_rate
173- return None
98+ bank_len = self ._noise_bank .shape [0 ]
99+
100+ if bank_len <= length :
101+ # Noise bank shorter than requested - pad it
102+ noise = self ._noise_bank .unsqueeze (0 ).expand (batch_size , - 1 )
103+ noise = F .pad (noise , (0 , length - bank_len ))
104+ return noise
105+
106+ # Generate random start indices for each sample
107+ max_start = bank_len - length
108+ starts = torch .randint (0 , max_start + 1 , (batch_size ,))
109+
110+ # Vectorized slicing: create index tensor [batch_size, length]
111+ # where each row is [start, start+1, ..., start+length-1]
112+ offsets = torch .arange (length )
113+ indices = starts .unsqueeze (1 ) + offsets .unsqueeze (0 ) # [batch_size, length]
114+ return self ._noise_bank [indices ]
174115
175116
176117@torch .no_grad ()
@@ -239,39 +180,25 @@ def _mix_noise(
239180@torch .no_grad ()
240181def add_noise (
241182 waveforms : torch .Tensor ,
242- sample_rate_or_loader : int | NoiseLoader ,
183+ loader : NoiseLoader ,
243184 * ,
244185 snr_low : float = 0.0 ,
245186 snr_high : float = 10.0 ,
246- noise_dir : str | None = None ,
247- download : bool = True ,
248- pack : str = "pointsource_noises" ,
249187) -> torch .Tensor :
250188 """Add point-source noise to each waveform in the batch.
251189
252190 Args:
253191 waveforms (torch.Tensor): The input waveforms. Shape [batch, time].
254- sample_rate_or_loader: Either the sample rate (int) for legacy behavior,
255- or a NoiseLoader instance for efficient background loading.
192+ loader: A NoiseLoader instance for efficient noise sampling.
256193 snr_low (float, optional): The minimum SNR in dB. Defaults to 0.0.
257194 snr_high (float, optional): The maximum SNR in dB. Defaults to 10.0.
258- noise_dir (str | None, optional): Directory containing noise files.
259- Only used when sample_rate_or_loader is an int. Defaults to None.
260- download (bool, optional): Whether to download noise files if not found.
261- Only used when sample_rate_or_loader is an int. Defaults to True.
262- pack (str, optional): The name of the noise pack to use.
263- Only used when sample_rate_or_loader is an int. Defaults to "pointsource_noises".
264195
265196 Returns:
266197 torch.Tensor: The waveforms with point-source noise added.
267198
268199 Example:
269- # Fast path with NoiseLoader (recommended):
270- loader = NoiseLoader("/path/to/noise", sample_rate=16000, num_workers=4)
200+ loader = NoiseLoader("/path/to/noise", sample_rate=16000)
271201 noisy = add_noise(waveforms, loader, snr_low=0, snr_high=10)
272-
273- # Legacy path (slower, loads from disk each call):
274- noisy = add_noise(waveforms, 16000, snr_low=0, snr_high=10, noise_dir="/path/to/noise")
275202 """
276203 if waveforms .ndim != 2 :
277204 raise AssertionError ("expected waveforms shaped [batch, time]" )
@@ -283,26 +210,8 @@ def add_noise(
283210 device = waveforms .device
284211 dtype = waveforms .dtype
285212
286- if isinstance (sample_rate_or_loader , NoiseLoader ):
287- # Fast path: use the NoiseLoader
288- noise = sample_rate_or_loader .get_batch (batch , total_time )
289- noise = noise .to (device = device , dtype = dtype )
290- else :
291- # Legacy path: load noise synchronously
292- sample_rate = sample_rate_or_loader
293-
294- if noise_dir is None and download :
295- from wav2aug .data .fetch import ensure_pack
296-
297- noise_dir = ensure_pack (pack )
298-
299- noises = []
300- for _ in range (batch ):
301- ref = torch .empty (1 , total_time , dtype = dtype )
302- sample = _sample_noise_like (ref , sample_rate , noise_dir )
303- noise_sample = sample .to (device = device , dtype = dtype ).view (- 1 )
304- noises .append (noise_sample )
305- noise = torch .stack (noises , dim = 0 )
213+ noise = loader .get_batch (batch , total_time )
214+ noise = noise .to (device = device , dtype = dtype )
306215
307216 return _mix_noise (
308217 waveforms ,
0 commit comments