Skip to content

Commit 8ed45a7

Browse files
committed
Merge branch 'downsample' into audio_tutorial
2 parents 7f9d3b0 + b3f37c7 commit 8ed45a7

File tree

5 files changed

+2075
-6
lines changed

5 files changed

+2075
-6
lines changed

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1494,8 +1494,11 @@ UniqueAVFrame VideoDecoder::convertAudioAVFrameSampleFormatAndSampleRate(
14941494
static_cast<const uint8_t**>(
14951495
const_cast<const uint8_t**>(srcAVFrame->data)),
14961496
srcAVFrame->nb_samples);
1497+
// numConvertedSamples can be 0 if we're downsampling by a great factor and
1498+
// the first frame doesn't contain a lot of samples. It should be handled
1499+
// properly by the caller.
14971500
TORCH_CHECK(
1498-
numConvertedSamples > 0,
1501+
numConvertedSamples >= 0,
14991502
"Error in swr_convert: ",
15001503
getFFMPEGErrorStringFromErrorCode(numConvertedSamples));
15011504

@@ -1522,17 +1525,22 @@ std::optional<torch::Tensor> VideoDecoder::maybeFlushSwrBuffers() {
15221525
return std::nullopt;
15231526
}
15241527

1525-
torch::Tensor lastSamples = torch::empty(
1526-
{getNumChannels(streamInfo.codecContext), numRemainingSamples},
1527-
torch::kFloat32);
1528-
uint8_t* lastSamplesData = static_cast<uint8_t*>(lastSamples.data_ptr());
1528+
auto numChannels = getNumChannels(streamInfo.codecContext);
1529+
torch::Tensor lastSamples =
1530+
torch::empty({numChannels, numRemainingSamples}, torch::kFloat32);
1531+
1532+
std::vector<uint8_t*> outputBuffers(numChannels);
1533+
for (auto i = 0; i < numChannels; i++) {
1534+
outputBuffers[i] = static_cast<uint8_t*>(lastSamples[i].data_ptr());
1535+
}
15291536

15301537
auto actualNumRemainingSamples = swr_convert(
15311538
streamInfo.swrContext.get(),
1532-
&lastSamplesData,
1539+
outputBuffers.data(),
15331540
numRemainingSamples,
15341541
nullptr,
15351542
0);
1543+
15361544
return lastSamples.narrow(
15371545
/*dim=*/1, /*start=*/0, /*length=*/actualNumRemainingSamples);
15381546
}

test/decoders/test_decoders.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
in_fbcode,
2525
NASA_AUDIO,
2626
NASA_AUDIO_MP3,
27+
NASA_AUDIO_MP3_44100,
2728
NASA_VIDEO,
2829
SINE_MONO_S16,
2930
SINE_MONO_S32,
@@ -1157,6 +1158,43 @@ def test_sample_rate_conversion(self, start_seconds, stop_seconds):
11571158
rtol=rtol,
11581159
)
11591160

1161+
def test_sample_rate_conversion_stereo(self):
1162+
# Non-regression test for https://github.com/pytorch/torchcodec/pull/584
1163+
asset = NASA_AUDIO_MP3
1164+
assert asset.sample_rate == 8000
1165+
assert asset.num_channels == 2
1166+
decoder = AudioDecoder(asset.path, sample_rate=44_100)
1167+
decoder.get_samples_played_in_range(start_seconds=0)
1168+
1169+
def test_downsample_empty_frame(self):
1170+
# Non-regression test for
1171+
# https://github.com/pytorch/torchcodec/pull/586: when downsampling by
1172+
# a great factor, if an input frame has a small amount of sample, the
1173+
# resampled frame (as output by swresample) may contain zero sample. We
1174+
# make sure we handle this properly.
1175+
#
1176+
# NASA_AUDIO_MP3_44100's first frame has only 47 samples which triggers
1177+
# the test scenario:
1178+
# ```
1179+
# » ffprobe -v error -hide_banner -select_streams a:0 -show_frames -of json test/resources/nasa_13013.mp4.audio_44100.mp3 | grep nb_samples | head -n 3
1180+
# "nb_samples": 47,
1181+
# "nb_samples": 1152,
1182+
# "nb_samples": 1152,
1183+
# ```
1184+
asset = NASA_AUDIO_MP3_44100
1185+
assert asset.sample_rate == 44_100
1186+
decoder = AudioDecoder(asset.path, sample_rate=8_000)
1187+
frames_44100_to_8000 = decoder.get_samples_played_in_range(start_seconds=0)
1188+
1189+
# Just checking correctness now
1190+
asset = NASA_AUDIO_MP3
1191+
assert asset.sample_rate == 8_000
1192+
decoder = AudioDecoder(asset.path)
1193+
frames_8000 = decoder.get_samples_played_in_range(start_seconds=0)
1194+
torch.testing.assert_close(
1195+
frames_44100_to_8000.data, frames_8000.data, atol=0.03, rtol=0
1196+
)
1197+
11601198
def test_s16_ffmpeg4_bug(self):
11611199
# s16 fails on FFmpeg4 but can be decoded on other versions.
11621200
# Debugging logs show that we're hitting:
205 KB
Binary file not shown.

0 commit comments

Comments
 (0)