Skip to content

Commit c45c9c6

Browse files
authored
Migrate encoder tests to public Python APIs (#694)
1 parent b4e958f commit c45c9c6

File tree

2 files changed

+222
-254
lines changed

2 files changed

+222
-254
lines changed

test/test_encoders.py

Lines changed: 222 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,28 @@
11
import re
2+
import subprocess
23

34
import pytest
45
import torch
6+
from torchcodec.decoders import AudioDecoder
57

68
from torchcodec.encoders import AudioEncoder
79

10+
from .utils import (
11+
get_ffmpeg_major_version,
12+
in_fbcode,
13+
NASA_AUDIO_MP3,
14+
SINE_MONO_S32,
15+
TestContainerFile,
16+
)
17+
818

919
class TestAudioEncoder:
1020

21+
def decode(self, source) -> torch.Tensor:
22+
if isinstance(source, TestContainerFile):
23+
source = str(source.path)
24+
return AudioDecoder(source).get_all_samples().data
25+
1126
def test_bad_input(self):
1227
with pytest.raises(ValueError, match="Expected samples to be a Tensor"):
1328
AudioEncoder(samples=123, sample_rate=32_000)
@@ -39,3 +54,210 @@ def test_bad_input(self):
3954
match=re.escape(f"Check the desired format? Got format={bad_format}"),
4055
):
4156
encoder.to_tensor(format=bad_format)
57+
58+
@pytest.mark.parametrize("method", ("to_file", "to_tensor"))
59+
def test_bad_input_parametrized(self, method):
60+
valid_params = (
61+
dict(dest="output.mp3") if method == "to_file" else dict(format="mp3")
62+
)
63+
64+
decoder = AudioEncoder(self.decode(NASA_AUDIO_MP3), sample_rate=10)
65+
with pytest.raises(RuntimeError, match="invalid sample rate=10"):
66+
getattr(decoder, method)(**valid_params)
67+
68+
decoder = AudioEncoder(
69+
self.decode(NASA_AUDIO_MP3), sample_rate=NASA_AUDIO_MP3.sample_rate
70+
)
71+
with pytest.raises(RuntimeError, match="bit_rate=-1 must be >= 0"):
72+
getattr(decoder, method)(**valid_params, bit_rate=-1)
73+
74+
bad_num_channels = 10
75+
decoder = AudioEncoder(torch.rand(bad_num_channels, 20), sample_rate=16_000)
76+
with pytest.raises(
77+
RuntimeError, match=f"Trying to encode {bad_num_channels} channels"
78+
):
79+
getattr(decoder, method)(**valid_params)
80+
81+
decoder = AudioEncoder(
82+
self.decode(NASA_AUDIO_MP3), sample_rate=NASA_AUDIO_MP3.sample_rate
83+
)
84+
for num_channels in (0, 3):
85+
with pytest.raises(
86+
RuntimeError,
87+
match=re.escape(
88+
f"Desired number of channels ({num_channels}) is not supported"
89+
),
90+
):
91+
getattr(decoder, method)(**valid_params, num_channels=num_channels)
92+
93+
@pytest.mark.parametrize("method", ("to_file", "to_tensor"))
94+
@pytest.mark.parametrize("format", ("wav", "flac"))
95+
def test_round_trip(self, method, format, tmp_path):
96+
# Check that decode(encode(samples)) == samples on lossless formats
97+
98+
if get_ffmpeg_major_version() == 4 and format == "wav":
99+
pytest.skip("Swresample with FFmpeg 4 doesn't work on wav files")
100+
101+
asset = NASA_AUDIO_MP3
102+
source_samples = self.decode(asset)
103+
104+
encoder = AudioEncoder(source_samples, sample_rate=asset.sample_rate)
105+
106+
if method == "to_file":
107+
encoded_path = str(tmp_path / f"output.{format}")
108+
encoded_source = encoded_path
109+
encoder.to_file(dest=encoded_path)
110+
else:
111+
encoded_source = encoder.to_tensor(format=format)
112+
assert encoded_source.dtype == torch.uint8
113+
assert encoded_source.ndim == 1
114+
115+
rtol, atol = (0, 1e-4) if format == "wav" else (None, None)
116+
torch.testing.assert_close(
117+
self.decode(encoded_source), source_samples, rtol=rtol, atol=atol
118+
)
119+
120+
@pytest.mark.skipif(in_fbcode(), reason="TODO: enable ffmpeg CLI")
121+
@pytest.mark.parametrize("asset", (NASA_AUDIO_MP3, SINE_MONO_S32))
122+
@pytest.mark.parametrize("bit_rate", (None, 0, 44_100, 999_999_999))
123+
@pytest.mark.parametrize("num_channels", (None, 1, 2))
124+
@pytest.mark.parametrize("format", ("mp3", "wav", "flac"))
125+
@pytest.mark.parametrize("method", ("to_file", "to_tensor"))
126+
def test_against_cli(self, asset, bit_rate, num_channels, format, method, tmp_path):
127+
# Encodes samples with our encoder and with the FFmpeg CLI, and checks
128+
# that both decoded outputs are equal
129+
130+
if get_ffmpeg_major_version() == 4 and format == "wav":
131+
pytest.skip("Swresample with FFmpeg 4 doesn't work on wav files")
132+
133+
encoded_by_ffmpeg = tmp_path / f"ffmpeg_output.{format}"
134+
subprocess.run(
135+
["ffmpeg", "-i", str(asset.path)]
136+
+ (["-b:a", f"{bit_rate}"] if bit_rate is not None else [])
137+
+ (["-ac", f"{num_channels}"] if num_channels is not None else [])
138+
+ [
139+
str(encoded_by_ffmpeg),
140+
],
141+
capture_output=True,
142+
check=True,
143+
)
144+
145+
encoder = AudioEncoder(self.decode(asset), sample_rate=asset.sample_rate)
146+
params = dict(bit_rate=bit_rate, num_channels=num_channels)
147+
if method == "to_file":
148+
encoded_by_us = tmp_path / f"output.{format}"
149+
encoder.to_file(dest=str(encoded_by_us), **params)
150+
else:
151+
encoded_by_us = encoder.to_tensor(format=format, **params)
152+
153+
if format == "wav":
154+
rtol, atol = 0, 1e-4
155+
elif format == "mp3" and asset is SINE_MONO_S32 and num_channels == 2:
156+
# Not sure why, this one needs slightly higher tol. With default
157+
# tolerances, the check fails on ~1% of the samples, so that's
158+
# probably fine. It might be that the FFmpeg CLI doesn't rely on
159+
# libswresample for converting channels?
160+
rtol, atol = 0, 1e-3
161+
else:
162+
rtol, atol = None, None
163+
torch.testing.assert_close(
164+
self.decode(encoded_by_ffmpeg),
165+
self.decode(encoded_by_us),
166+
rtol=rtol,
167+
atol=atol,
168+
)
169+
170+
@pytest.mark.parametrize("asset", (NASA_AUDIO_MP3, SINE_MONO_S32))
171+
@pytest.mark.parametrize("bit_rate", (None, 0, 44_100, 999_999_999))
172+
@pytest.mark.parametrize("num_channels", (None, 1, 2))
173+
@pytest.mark.parametrize("format", ("mp3", "wav", "flac"))
174+
def test_to_tensor_against_to_file(
175+
self, asset, bit_rate, num_channels, format, tmp_path
176+
):
177+
if get_ffmpeg_major_version() == 4 and format == "wav":
178+
pytest.skip("Swresample with FFmpeg 4 doesn't work on wav files")
179+
180+
encoder = AudioEncoder(self.decode(asset), sample_rate=asset.sample_rate)
181+
182+
params = dict(bit_rate=bit_rate, num_channels=num_channels)
183+
encoded_file = tmp_path / f"output.{format}"
184+
encoder.to_file(dest=str(encoded_file), **params)
185+
encoded_tensor = encoder.to_tensor(
186+
format=format, bit_rate=bit_rate, num_channels=num_channels
187+
)
188+
189+
torch.testing.assert_close(
190+
self.decode(encoded_file), self.decode(encoded_tensor)
191+
)
192+
193+
def test_encode_to_tensor_long_output(self):
194+
# Check that we support re-allocating the output tensor when the encoded
195+
# data is large.
196+
samples = torch.rand(1, int(1e7))
197+
encoded_tensor = AudioEncoder(samples, sample_rate=16_000).to_tensor(
198+
format="flac", bit_rate=44_000
199+
)
200+
201+
# Note: this should be in sync with its C++ counterpart for the test to
202+
# be meaningful.
203+
INITIAL_TENSOR_SIZE = 10_000_000
204+
assert encoded_tensor.numel() > INITIAL_TENSOR_SIZE
205+
206+
torch.testing.assert_close(self.decode(encoded_tensor), samples)
207+
208+
def test_contiguity(self):
209+
# Ensure that 2 waveforms with the same values are encoded in the same
210+
# way, regardless of their memory layout. Here we encode 2 equal
211+
# waveforms, one is row-aligned while the other is column-aligned.
212+
# TODO: Ideally we'd be testing all encoding methods here
213+
214+
num_samples = 10_000 # per channel
215+
contiguous_samples = torch.rand(2, num_samples).contiguous()
216+
assert contiguous_samples.stride() == (num_samples, 1)
217+
218+
params = dict(format="flac", bit_rate=44_000)
219+
encoded_from_contiguous = AudioEncoder(
220+
contiguous_samples, sample_rate=16_000
221+
).to_tensor(**params)
222+
223+
non_contiguous_samples = contiguous_samples.T.contiguous().T
224+
assert non_contiguous_samples.stride() == (1, 2)
225+
226+
torch.testing.assert_close(
227+
contiguous_samples, non_contiguous_samples, rtol=0, atol=0
228+
)
229+
230+
encoded_from_non_contiguous = AudioEncoder(
231+
non_contiguous_samples, sample_rate=16_000
232+
).to_tensor(**params)
233+
234+
torch.testing.assert_close(
235+
encoded_from_contiguous, encoded_from_non_contiguous, rtol=0, atol=0
236+
)
237+
238+
@pytest.mark.parametrize("num_channels_input", (1, 2))
239+
@pytest.mark.parametrize("num_channels_output", (1, 2, None))
240+
@pytest.mark.parametrize("method", ("to_file", "to_tensor"))
241+
def test_num_channels(
242+
self, num_channels_input, num_channels_output, method, tmp_path
243+
):
244+
# We just check that the num_channels parameter is respected.
245+
# Correctness is checked in other tests (like test_against_cli())
246+
247+
sample_rate = 16_000
248+
source_samples = torch.rand(num_channels_input, 1_000)
249+
format = "mp3"
250+
251+
encoder = AudioEncoder(source_samples, sample_rate=sample_rate)
252+
params = dict(num_channels=num_channels_output)
253+
254+
if method == "to_file":
255+
encoded_path = str(tmp_path / f"output.{format}")
256+
encoded_source = encoded_path
257+
encoder.to_file(dest=encoded_path, **params)
258+
else:
259+
encoded_source = encoder.to_tensor(format=format, **params)
260+
261+
if num_channels_output is None:
262+
num_channels_output = num_channels_input
263+
assert self.decode(encoded_source).shape[0] == num_channels_output

0 commit comments

Comments
 (0)