Skip to content

Commit 96e5e60

Browse files
committed
Merge branch 'main' of github.com:pytorch/torchcodec into migrate_encoding_test
2 parents 5d9eb54 + b4e958f commit 96e5e60

File tree

3 files changed

+246
-4
lines changed

3 files changed

+246
-4
lines changed

src/torchcodec/_core/Encoder.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,8 @@ void AudioEncoder::initializeEncoder(
179179

180180
desiredNumChannels_ = static_cast<int>(numChannels.value_or(wf_.sizes()[0]));
181181
validateNumChannels(*avCodec, desiredNumChannels_);
182+
// The avCodecContext layout defines the layout of the encoded output, it's
183+
// not related to the input sampes.
182184
setDefaultChannelLayout(avCodecContext_, desiredNumChannels_);
183185

184186
validateSampleRate(*avCodec, sampleRate);
@@ -233,6 +235,8 @@ void AudioEncoder::encode() {
233235
avFrame->format = AV_SAMPLE_FMT_FLTP;
234236
avFrame->sample_rate = avCodecContext_->sample_rate;
235237
avFrame->pts = 0;
238+
// We set the channel layout of the frame to the default layout corresponding
239+
// to the input samples' number of channels
236240
setDefaultChannelLayout(avFrame, static_cast<int>(wf_.sizes()[0]));
237241

238242
auto status = av_frame_get_buffer(avFrame.get(), 0);

src/torchcodec/_core/FFMPEGCommon.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -107,15 +107,17 @@ void validateNumChannels(const AVCodec& avCodec, int numChannels) {
107107
// eventually raise.
108108
return;
109109
}
110-
for (auto i = 0; avCodec.ch_layouts[i].order != AV_CHANNEL_ORDER_UNSPEC;
111-
++i) {
110+
// FFmpeg doc indicate that the ch_layouts array is terminated by a zeroed
111+
// layout, so checking for nb_channels == 0 should indicate its end.
112+
for (auto i = 0; avCodec.ch_layouts[i].nb_channels != 0; ++i) {
112113
if (numChannels == avCodec.ch_layouts[i].nb_channels) {
113114
return;
114115
}
115116
}
117+
// At this point it seems that the encoder doesn't support the requested
118+
// number of channels, so we error out.
116119
std::stringstream supportedNumChannels;
117-
for (auto i = 0; avCodec.ch_layouts[i].order != AV_CHANNEL_ORDER_UNSPEC;
118-
++i) {
120+
for (auto i = 0; avCodec.ch_layouts[i].nb_channels != 0; ++i) {
119121
if (i > 0) {
120122
supportedNumChannels << ", ";
121123
}
@@ -132,6 +134,8 @@ void validateNumChannels(const AVCodec& avCodec, int numChannels) {
132134
return;
133135
}
134136
}
137+
// At this point it seems that the encoder doesn't support the requested
138+
// number of channels, so we error out.
135139
std::stringstream supportedNumChannels;
136140
for (auto i = 0; avCodec.channel_layouts[i] != 0; ++i) {
137141
if (i > 0) {

test/test_ops.py

Lines changed: 234 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import io
88
import os
9+
import re
910
from functools import partial
1011

1112
os.environ["TORCH_LOGS"] = "output_code"
@@ -1119,6 +1120,239 @@ def test_bad_input(self, tmp_path):
11191120
wf=torch.rand(2, 10), sample_rate=10, filename="./file.bad_extension"
11201121
)
11211122

1123+
with pytest.raises(RuntimeError, match="invalid sample rate=10"):
1124+
encode_audio_to_file(
1125+
wf=self.decode(NASA_AUDIO_MP3),
1126+
sample_rate=10,
1127+
filename=valid_output_file,
1128+
)
1129+
with pytest.raises(RuntimeError, match="bit_rate=-1 must be >= 0"):
1130+
encode_audio_to_file(
1131+
wf=self.decode(NASA_AUDIO_MP3),
1132+
sample_rate=NASA_AUDIO_MP3.sample_rate,
1133+
filename=valid_output_file,
1134+
bit_rate=-1, # bad
1135+
)
1136+
1137+
with pytest.raises(RuntimeError, match="Trying to encode 10 channels"):
1138+
encode_audio_to_file(
1139+
wf=torch.rand(10, 20), sample_rate=10, filename="doesnt_matter"
1140+
)
1141+
1142+
for num_channels in (0, 3):
1143+
with pytest.raises(
1144+
RuntimeError,
1145+
match=re.escape(
1146+
f"Desired number of channels ({num_channels}) is not supported"
1147+
),
1148+
):
1149+
encode_audio_to_file(
1150+
wf=torch.rand(2, 10),
1151+
sample_rate=16_000,
1152+
filename="ok.mp3",
1153+
num_channels=num_channels,
1154+
)
1155+
1156+
@pytest.mark.parametrize(
1157+
"encode_method", (encode_audio_to_file, encode_audio_to_tensor)
1158+
)
1159+
@pytest.mark.parametrize("output_format", ("wav", "flac"))
1160+
def test_round_trip(self, encode_method, output_format, tmp_path):
1161+
# Check that decode(encode(samples)) == samples on lossless formats
1162+
1163+
if get_ffmpeg_major_version() == 4 and output_format == "wav":
1164+
pytest.skip("Swresample with FFmpeg 4 doesn't work on wav files")
1165+
1166+
asset = NASA_AUDIO_MP3
1167+
source_samples = self.decode(asset)
1168+
1169+
if encode_method is encode_audio_to_file:
1170+
encoded_path = tmp_path / f"output.{output_format}"
1171+
encode_audio_to_file(
1172+
wf=source_samples,
1173+
sample_rate=asset.sample_rate,
1174+
filename=str(encoded_path),
1175+
)
1176+
encoded_source = encoded_path
1177+
else:
1178+
encoded_source = encode_audio_to_tensor(
1179+
wf=source_samples, sample_rate=asset.sample_rate, format=output_format
1180+
)
1181+
assert encoded_source.dtype == torch.uint8
1182+
assert encoded_source.ndim == 1
1183+
1184+
rtol, atol = (0, 1e-4) if output_format == "wav" else (None, None)
1185+
torch.testing.assert_close(
1186+
self.decode(encoded_source), source_samples, rtol=rtol, atol=atol
1187+
)
1188+
1189+
@pytest.mark.skipif(in_fbcode(), reason="TODO: enable ffmpeg CLI")
1190+
@pytest.mark.parametrize("asset", (NASA_AUDIO_MP3, SINE_MONO_S32))
1191+
@pytest.mark.parametrize("bit_rate", (None, 0, 44_100, 999_999_999))
1192+
@pytest.mark.parametrize("num_channels", (None, 1, 2))
1193+
@pytest.mark.parametrize("output_format", ("mp3", "wav", "flac"))
1194+
def test_against_cli(self, asset, bit_rate, num_channels, output_format, tmp_path):
1195+
# Encodes samples with our encoder and with the FFmpeg CLI, and checks
1196+
# that both decoded outputs are equal
1197+
1198+
if get_ffmpeg_major_version() == 4 and output_format == "wav":
1199+
pytest.skip("Swresample with FFmpeg 4 doesn't work on wav files")
1200+
1201+
encoded_by_ffmpeg = tmp_path / f"ffmpeg_output.{output_format}"
1202+
subprocess.run(
1203+
["ffmpeg", "-i", str(asset.path)]
1204+
+ (["-b:a", f"{bit_rate}"] if bit_rate is not None else [])
1205+
+ (["-ac", f"{num_channels}"] if num_channels is not None else [])
1206+
+ [
1207+
str(encoded_by_ffmpeg),
1208+
],
1209+
capture_output=True,
1210+
check=True,
1211+
)
1212+
1213+
encoded_by_us = tmp_path / f"our_output.{output_format}"
1214+
encode_audio_to_file(
1215+
wf=self.decode(asset),
1216+
sample_rate=asset.sample_rate,
1217+
filename=str(encoded_by_us),
1218+
bit_rate=bit_rate,
1219+
num_channels=num_channels,
1220+
)
1221+
1222+
if output_format == "wav":
1223+
rtol, atol = 0, 1e-4
1224+
elif output_format == "mp3" and asset is SINE_MONO_S32 and num_channels == 2:
1225+
# Not sure why, this one needs slightly higher tol. With default
1226+
# tolerances, the check fails on ~1% of the samples, so that's
1227+
# probably fine. It might be that the FFmpeg CLI doesn't rely on
1228+
# libswresample for converting channels?
1229+
rtol, atol = 0, 1e-3
1230+
else:
1231+
rtol, atol = None, None
1232+
torch.testing.assert_close(
1233+
self.decode(encoded_by_ffmpeg),
1234+
self.decode(encoded_by_us),
1235+
rtol=rtol,
1236+
atol=atol,
1237+
)
1238+
1239+
@pytest.mark.parametrize("asset", (NASA_AUDIO_MP3, SINE_MONO_S32))
1240+
@pytest.mark.parametrize("bit_rate", (None, 0, 44_100, 999_999_999))
1241+
@pytest.mark.parametrize("num_channels", (None, 1, 2))
1242+
@pytest.mark.parametrize("output_format", ("mp3", "wav", "flac"))
1243+
def test_tensor_against_file(
1244+
self, asset, bit_rate, num_channels, output_format, tmp_path
1245+
):
1246+
if get_ffmpeg_major_version() == 4 and output_format == "wav":
1247+
pytest.skip("Swresample with FFmpeg 4 doesn't work on wav files")
1248+
1249+
encoded_file = tmp_path / f"our_output.{output_format}"
1250+
encode_audio_to_file(
1251+
wf=self.decode(asset),
1252+
sample_rate=asset.sample_rate,
1253+
filename=str(encoded_file),
1254+
bit_rate=bit_rate,
1255+
num_channels=num_channels,
1256+
)
1257+
1258+
encoded_tensor = encode_audio_to_tensor(
1259+
wf=self.decode(asset),
1260+
sample_rate=asset.sample_rate,
1261+
format=output_format,
1262+
bit_rate=bit_rate,
1263+
num_channels=num_channels,
1264+
)
1265+
1266+
torch.testing.assert_close(
1267+
self.decode(encoded_file), self.decode(encoded_tensor)
1268+
)
1269+
1270+
def test_encode_to_tensor_long_output(self):
1271+
# Check that we support re-allocating the output tensor when the encoded
1272+
# data is large.
1273+
samples = torch.rand(1, int(1e7))
1274+
encoded_tensor = encode_audio_to_tensor(
1275+
wf=samples,
1276+
sample_rate=16_000,
1277+
format="flac",
1278+
bit_rate=44_000,
1279+
)
1280+
# Note: this should be in sync with its C++ counterpart for the test to
1281+
# be meaningful.
1282+
INITIAL_TENSOR_SIZE = 10_000_000
1283+
assert encoded_tensor.numel() > INITIAL_TENSOR_SIZE
1284+
1285+
torch.testing.assert_close(self.decode(encoded_tensor), samples)
1286+
1287+
def test_contiguity(self):
1288+
# Ensure that 2 waveforms with the same values are encoded in the same
1289+
# way, regardless of their memory layout. Here we encode 2 equal
1290+
# waveforms, one is row-aligned while the other is column-aligned.
1291+
1292+
num_samples = 10_000 # per channel
1293+
contiguous_samples = torch.rand(2, num_samples).contiguous()
1294+
assert contiguous_samples.stride() == (num_samples, 1)
1295+
1296+
encoded_from_contiguous = encode_audio_to_tensor(
1297+
wf=contiguous_samples,
1298+
sample_rate=16_000,
1299+
format="flac",
1300+
bit_rate=44_000,
1301+
)
1302+
non_contiguous_samples = contiguous_samples.T.contiguous().T
1303+
assert non_contiguous_samples.stride() == (1, 2)
1304+
1305+
torch.testing.assert_close(
1306+
contiguous_samples, non_contiguous_samples, rtol=0, atol=0
1307+
)
1308+
1309+
encoded_from_non_contiguous = encode_audio_to_tensor(
1310+
wf=non_contiguous_samples,
1311+
sample_rate=16_000,
1312+
format="flac",
1313+
bit_rate=44_000,
1314+
)
1315+
1316+
torch.testing.assert_close(
1317+
encoded_from_contiguous, encoded_from_non_contiguous, rtol=0, atol=0
1318+
)
1319+
1320+
@pytest.mark.parametrize("num_channels_input", (1, 2))
1321+
@pytest.mark.parametrize("num_channels_output", (1, 2, None))
1322+
@pytest.mark.parametrize(
1323+
"encode_method", (encode_audio_to_file, encode_audio_to_tensor)
1324+
)
1325+
def test_num_channels(
1326+
self, num_channels_input, num_channels_output, encode_method, tmp_path
1327+
):
1328+
# We just check that the num_channels parameter is respected.
1329+
# Correctness is checked in other tests (like test_against_cli())
1330+
1331+
sample_rate = 16_000
1332+
source_samples = torch.rand(num_channels_input, 1_000)
1333+
format = "mp3"
1334+
1335+
if encode_method is encode_audio_to_file:
1336+
encoded_path = tmp_path / f"output.{format}"
1337+
encode_audio_to_file(
1338+
wf=source_samples,
1339+
sample_rate=sample_rate,
1340+
filename=str(encoded_path),
1341+
num_channels=num_channels_output,
1342+
)
1343+
encoded_source = encoded_path
1344+
else:
1345+
encoded_source = encode_audio_to_tensor(
1346+
wf=source_samples,
1347+
sample_rate=sample_rate,
1348+
format=format,
1349+
num_channels=num_channels_output,
1350+
)
1351+
1352+
if num_channels_output is None:
1353+
num_channels_output = num_channels_input
1354+
assert self.decode(encoded_source).shape[0] == num_channels_output
1355+
11221356

11231357
if __name__ == "__main__":
11241358
pytest.main()

0 commit comments

Comments
 (0)