Skip to content

Commit 0641f2e

Browse files
authored
make noise dtype passable #7 (#9)
* make noise_dtype passable and rm dataloader * make noise_dtype passable and rm dataloader * bump version in uv.lock
1 parent 682ad92 commit 0641f2e

File tree

4 files changed

+71
-144
lines changed

4 files changed

+71
-144
lines changed

tests/test_gpu_augmentations.py

Lines changed: 30 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -63,25 +63,21 @@ def test_freq_drop_no_nan_and_inplace():
6363
assert torch.isnan(out).logical_not().all()
6464

6565

66-
def test_add_noise_with_stub(monkeypatch):
67-
def _stub_noise_like(ref, sample_rate, noise_dir):
68-
return torch.zeros_like(ref)
66+
def test_add_noise_with_mock_loader():
67+
"""Test add_noise with a mock NoiseLoader."""
68+
from unittest.mock import MagicMock
6969

70-
monkeypatch.setattr(
71-
"wav2aug.gpu.noise_addition._sample_noise_like", _stub_noise_like
72-
)
7370
waveforms = torch.ones(2, 128, device=DEVICE, dtype=torch.float32)
7471
ptr = waveforms.data_ptr()
75-
out = add_noise(
76-
waveforms,
77-
16_000, # sample_rate as positional argument
78-
snr_low=0.0,
79-
snr_high=0.0,
80-
download=False,
81-
noise_dir="ignored",
82-
)
72+
73+
# Create mock loader that returns zeros
74+
mock_loader = MagicMock()
75+
mock_loader.get_batch.return_value = torch.zeros(2, 128)
76+
77+
out = add_noise(waveforms, mock_loader, snr_low=0.0, snr_high=0.0)
8378
assert out.data_ptr() == ptr
8479
assert torch.isfinite(out).all()
80+
mock_loader.get_batch.assert_called_once_with(2, 128)
8581

8682

8783
def test_add_babble_noise_identity_for_singleton_batch():
@@ -127,7 +123,7 @@ def test_time_dropout_zeroes_segments():
127123

128124

129125
def test_wav2aug_runs_with_stubbed_noise(monkeypatch):
130-
def _noop_add_noise(waveforms, sample_rate, **kwargs):
126+
def _noop_add_noise(waveforms, loader, **kwargs):
131127
return waveforms
132128

133129
monkeypatch.setattr("wav2aug.gpu.wav2aug.add_noise", _noop_add_noise)
@@ -168,3 +164,22 @@ def test_wav2aug_top_k_invalid_raises():
168164

169165
with pytest.raises(ValueError, match="top_k must be between 1 and 9"):
170166
Wav2Aug(sample_rate=16_000, top_k=10)
167+
168+
169+
def test_wav2aug_noise_dtype(monkeypatch):
170+
"""Test that noise_dtype is passed to NoiseLoader."""
171+
172+
def _noop_add_noise(waveforms, loader, **kwargs):
173+
return waveforms
174+
175+
monkeypatch.setattr("wav2aug.gpu.wav2aug.add_noise", _noop_add_noise)
176+
177+
# Default should be float32
178+
aug = Wav2Aug(sample_rate=16_000)
179+
assert aug.noise_dtype == torch.float32
180+
assert aug._noise_loader.storage_dtype == torch.float32
181+
182+
# Custom dtype should be passed through
183+
aug = Wav2Aug(sample_rate=16_000, noise_dtype=torch.float16)
184+
assert aug.noise_dtype == torch.float16
185+
assert aug._noise_loader.storage_dtype == torch.float16

uv.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

wav2aug/gpu/noise_addition.py

Lines changed: 33 additions & 124 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@
66
import torch.nn.functional as F
77
from tqdm import tqdm
88

9-
from wav2aug.utils._aug_utils import _sample_noise_like
10-
119
_EPS = 1e-14
1210
_AUDIO_EXTS = {".wav", ".flac", ".mp3", ".ogg", ".opus", ".m4a"}
1311

@@ -23,21 +21,17 @@ def _list_audio_files(root: str) -> list[str]:
2321

2422

2523
class NoiseLoader:
26-
"""Noise loader with preload-to-memory or on-demand loading.
24+
"""Noise loader that preloads all noise files into CPU RAM.
2725
28-
By default, loads all noise files into CPU RAM at initialization for
29-
zero-I/O sampling during training. For memory-constrained environments,
30-
set preload=False to load files on-demand.
26+
Loads all noise files into CPU RAM at initialization for
27+
zero-I/O sampling during training.
3128
3229
Usage:
33-
# Preload mode (default, recommended):
30+
# Default:
3431
noise_loader = NoiseLoader(noise_dir, sample_rate=16000)
3532
36-
# On-demand mode (for memory-constrained systems):
37-
noise_loader = NoiseLoader(noise_dir, sample_rate=16000, preload=False)
38-
39-
# Custom storage dtype (e.g., for even lower memory):
40-
noise_loader = NoiseLoader(noise_dir, sample_rate=16000, storage_dtype=torch.float8_e4m3fn)
33+
# Custom storage dtype (e.g., for lower memory):
34+
noise_loader = NoiseLoader(noise_dir, sample_rate=16000, storage_dtype=torch.float16)
4135
4236
# In training loop:
4337
noisy = add_noise(waveforms, noise_loader, snr_low=0, snr_high=10)
@@ -47,39 +41,27 @@ def __init__(
4741
self,
4842
noise_dir: str,
4943
sample_rate: int,
50-
preload: bool = True,
51-
storage_dtype: torch.dtype = torch.float16,
44+
storage_dtype: torch.dtype = torch.float32,
5245
):
5346
"""Initialize the noise loader.
5447
5548
Args:
5649
noise_dir: Directory containing noise audio files.
5750
sample_rate: Target sample rate for noise.
58-
preload: If True (default), load all noise files into CPU RAM at
59-
initialization. Sampling becomes a fast tensor slice operation
60-
with no I/O. If False, load files on-demand (slower but uses
61-
less memory).
6251
storage_dtype: Data type for storing preloaded audio in memory.
63-
Defaults to float16 (~650MB for pointsource_noises). Use float32
64-
for maximum precision, or float8 variants for minimum memory. Note: In
65-
my experiments, float16 halved memory usage in exchange for an
66-
extremely tiny performance degradation.
52+
Defaults to float32. Use float16 for lower memory usage.
6753
"""
6854
self.noise_dir = noise_dir
6955
self.sample_rate = sample_rate
70-
self.preload = preload
7156
self.storage_dtype = storage_dtype
7257
self.files = _list_audio_files(noise_dir)
7358
if not self.files:
7459
raise ValueError(f"No audio files found in {noise_dir}")
7560

7661
# Preloaded noise bank (1D tensor of all concatenated noise)
77-
self._noise_bank: torch.Tensor | None = None
78-
79-
if preload:
80-
self._preload_all()
62+
self._noise_bank: torch.Tensor = self._preload_all()
8163

82-
def _preload_all(self) -> None:
64+
def _preload_all(self) -> torch.Tensor:
8365
"""Load all noise files into memory."""
8466
from torchcodec.decoders import AudioDecoder
8567

@@ -101,17 +83,7 @@ def _preload_all(self) -> None:
10183
f"No valid audio files could be loaded from {self.noise_dir}"
10284
)
10385

104-
self._noise_bank = torch.cat(chunks, dim=0)
105-
106-
def _load_one(self) -> torch.Tensor:
107-
"""Load a single noise sample directly (no preloading)."""
108-
from torchcodec.decoders import AudioDecoder
109-
110-
idx = torch.randint(0, len(self.files), (1,)).item()
111-
dec = AudioDecoder(self.files[idx], sample_rate=self.sample_rate)
112-
samp = dec.get_all_samples()
113-
audio = samp.data.contiguous().mean(dim=0) # mono, shape [time]
114-
return audio
86+
return torch.cat(chunks, dim=0)
11587

11688
def get_batch(self, batch_size: int, length: int) -> torch.Tensor:
11789
"""Get a batch of noise samples.
@@ -123,54 +95,23 @@ def get_batch(self, batch_size: int, length: int) -> torch.Tensor:
12395
Returns:
12496
Tensor of shape [batch_size, length] on CPU.
12597
"""
126-
if self._noise_bank is not None:
127-
# Fast path: slice from preloaded noise bank
128-
bank_len = self._noise_bank.shape[0]
129-
130-
if bank_len <= length:
131-
# Noise bank shorter than requested - pad it
132-
noise = self._noise_bank.unsqueeze(0).expand(batch_size, -1)
133-
noise = F.pad(noise, (0, length - bank_len))
134-
return noise
135-
136-
# Generate random start indices for each sample
137-
max_start = bank_len - length
138-
starts = torch.randint(0, max_start + 1, (batch_size,))
139-
140-
# Vectorized slicing: create index tensor [batch_size, length]
141-
# where each row is [start, start+1, ..., start+length-1]
142-
offsets = torch.arange(length)
143-
indices = starts.unsqueeze(1) + offsets.unsqueeze(0) # [batch_size, length]
144-
return self._noise_bank[indices]
145-
else:
146-
# On-demand loading
147-
noises = []
148-
for _ in range(batch_size):
149-
noise = self._load_one()
150-
noise = self._pad_or_crop(noise, length)
151-
noises.append(noise)
152-
return torch.stack(noises, dim=0)
153-
154-
def _pad_or_crop(self, noise: torch.Tensor, length: int) -> torch.Tensor:
155-
"""Pad or crop noise to target length."""
156-
if noise.shape[0] < length:
157-
noise = F.pad(noise, (0, length - noise.shape[0]))
158-
elif noise.shape[0] > length:
159-
start = torch.randint(0, noise.shape[0] - length + 1, (1,)).item()
160-
noise = noise[start : start + length]
161-
return noise
162-
163-
@property
164-
def mode(self) -> str:
165-
"""Return current loading mode: 'preload' or 'on-demand'."""
166-
return "preload" if self._noise_bank is not None else "on-demand"
167-
168-
@property
169-
def preloaded_duration_seconds(self) -> float | None:
170-
"""Total duration of preloaded audio in seconds, or None if not preloaded."""
171-
if self._noise_bank is not None:
172-
return self._noise_bank.shape[0] / self.sample_rate
173-
return None
98+
bank_len = self._noise_bank.shape[0]
99+
100+
if bank_len <= length:
101+
# Noise bank shorter than requested - pad it
102+
noise = self._noise_bank.unsqueeze(0).expand(batch_size, -1)
103+
noise = F.pad(noise, (0, length - bank_len))
104+
return noise
105+
106+
# Generate random start indices for each sample
107+
max_start = bank_len - length
108+
starts = torch.randint(0, max_start + 1, (batch_size,))
109+
110+
# Vectorized slicing: create index tensor [batch_size, length]
111+
# where each row is [start, start+1, ..., start+length-1]
112+
offsets = torch.arange(length)
113+
indices = starts.unsqueeze(1) + offsets.unsqueeze(0) # [batch_size, length]
114+
return self._noise_bank[indices]
174115

175116

176117
@torch.no_grad()
@@ -239,39 +180,25 @@ def _mix_noise(
239180
@torch.no_grad()
240181
def add_noise(
241182
waveforms: torch.Tensor,
242-
sample_rate_or_loader: int | NoiseLoader,
183+
loader: NoiseLoader,
243184
*,
244185
snr_low: float = 0.0,
245186
snr_high: float = 10.0,
246-
noise_dir: str | None = None,
247-
download: bool = True,
248-
pack: str = "pointsource_noises",
249187
) -> torch.Tensor:
250188
"""Add point-source noise to each waveform in the batch.
251189
252190
Args:
253191
waveforms (torch.Tensor): The input waveforms. Shape [batch, time].
254-
sample_rate_or_loader: Either the sample rate (int) for legacy behavior,
255-
or a NoiseLoader instance for efficient background loading.
192+
loader: A NoiseLoader instance for efficient noise sampling.
256193
snr_low (float, optional): The minimum SNR in dB. Defaults to 0.0.
257194
snr_high (float, optional): The maximum SNR in dB. Defaults to 10.0.
258-
noise_dir (str | None, optional): Directory containing noise files.
259-
Only used when sample_rate_or_loader is an int. Defaults to None.
260-
download (bool, optional): Whether to download noise files if not found.
261-
Only used when sample_rate_or_loader is an int. Defaults to True.
262-
pack (str, optional): The name of the noise pack to use.
263-
Only used when sample_rate_or_loader is an int. Defaults to "pointsource_noises".
264195
265196
Returns:
266197
torch.Tensor: The waveforms with point-source noise added.
267198
268199
Example:
269-
# Fast path with NoiseLoader (recommended):
270-
loader = NoiseLoader("/path/to/noise", sample_rate=16000, num_workers=4)
200+
loader = NoiseLoader("/path/to/noise", sample_rate=16000)
271201
noisy = add_noise(waveforms, loader, snr_low=0, snr_high=10)
272-
273-
# Legacy path (slower, loads from disk each call):
274-
noisy = add_noise(waveforms, 16000, snr_low=0, snr_high=10, noise_dir="/path/to/noise")
275202
"""
276203
if waveforms.ndim != 2:
277204
raise AssertionError("expected waveforms shaped [batch, time]")
@@ -283,26 +210,8 @@ def add_noise(
283210
device = waveforms.device
284211
dtype = waveforms.dtype
285212

286-
if isinstance(sample_rate_or_loader, NoiseLoader):
287-
# Fast path: use the NoiseLoader
288-
noise = sample_rate_or_loader.get_batch(batch, total_time)
289-
noise = noise.to(device=device, dtype=dtype)
290-
else:
291-
# Legacy path: load noise synchronously
292-
sample_rate = sample_rate_or_loader
293-
294-
if noise_dir is None and download:
295-
from wav2aug.data.fetch import ensure_pack
296-
297-
noise_dir = ensure_pack(pack)
298-
299-
noises = []
300-
for _ in range(batch):
301-
ref = torch.empty(1, total_time, dtype=dtype)
302-
sample = _sample_noise_like(ref, sample_rate, noise_dir)
303-
noise_sample = sample.to(device=device, dtype=dtype).view(-1)
304-
noises.append(noise_sample)
305-
noise = torch.stack(noises, dim=0)
213+
noise = loader.get_batch(batch, total_time)
214+
noise = noise.to(device=device, dtype=dtype)
306215

307216
return _mix_noise(
308217
waveforms,

wav2aug/gpu/wav2aug.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,31 +21,34 @@ def __init__(
2121
self,
2222
sample_rate: int,
2323
noise_dir: str | None = None,
24-
noise_preload: bool = True,
2524
top_k: int = 9,
25+
noise_dtype: torch.dtype = torch.float32,
2626
) -> None:
2727
"""Initialize Wav2Aug.
2828
2929
Args:
3030
sample_rate: Audio sample rate in Hz.
3131
noise_dir: Directory containing noise files. If None, will use the
3232
default cached noise pack (auto-downloaded if needed).
33-
noise_preload: If True (default), preload all noise files into CPU RAM
34-
at initialization for fast sampling. If False, load files on-demand.
3533
top_k: Number of top augmentations to use, ordered by effectiveness.
3634
Default is 9 (all augmentations). Common values: 3, 6, or 9.
3735
Order (best to worst): Noise Addition, Freq Drop, Time Drop,
3836
Speed Perturb, Amp Clip, Chunk Swap, Babble Noise, Amp Scale,
3937
Polarity Inversion.
38+
noise_dtype: Data type for storing preloaded noise in memory.
39+
Defaults to float32. Use float16 for memory efficiency.
4040
"""
4141
self.sample_rate = int(sample_rate)
42+
self.noise_dtype = noise_dtype
4243

4344
# Initialize noise loader
4445
if noise_dir is None:
4546
from wav2aug.data.fetch import ensure_pack
4647

4748
noise_dir = ensure_pack("pointsource_noises")
48-
self._noise_loader = NoiseLoader(noise_dir, sample_rate, preload=noise_preload)
49+
self._noise_loader = NoiseLoader(
50+
noise_dir, sample_rate, storage_dtype=noise_dtype
51+
)
4952

5053
# All ops ordered by effectiveness (best first)
5154
all_ops: List[Callable[[torch.Tensor, torch.Tensor | None], torch.Tensor]] = [

0 commit comments

Comments
 (0)