Skip to content

Commit f6a7f4e

Browse files
committed
WIP
1 parent 8b2ad5b commit f6a7f4e

File tree

5 files changed

+109
-8
lines changed

5 files changed

+109
-8
lines changed

src/torchcodec/decoders/_core/CMakeLists.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@ set(CMAKE_CXX_STANDARD 17)
44
set(CMAKE_CXX_STANDARD_REQUIRED ON)
55

66
find_package(Torch REQUIRED)
7-
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -pedantic -Werror ${TORCH_CXX_FLAGS}")
7+
# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -pedantic -Werror ${TORCH_CXX_FLAGS}")
8+
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra ${TORCH_CXX_FLAGS}")
89
find_package(Python3 ${PYTHON_VERSION} EXACT COMPONENTS Development)
910

1011
function(make_torchcodec_library library_name ffmpeg_target)
@@ -97,6 +98,7 @@ else()
9798
libavformat
9899
libavcodec
99100
libavutil
101+
libswresample
100102
libswscale
101103
)
102104

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 86 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ extern "C" {
2323
#include <libavutil/imgutils.h>
2424
#include <libavutil/log.h>
2525
#include <libavutil/pixdesc.h>
26+
#include <libswresample/swresample.h>
2627
#include <libswscale/swscale.h>
2728
}
2829

@@ -541,14 +542,18 @@ void VideoDecoder::addVideoStream(
541542
videoStreamOptions.colorConversionLibrary.value_or(defaultLibrary);
542543
}
543544

544-
void VideoDecoder::addAudioStream(int streamIndex) {
545+
void VideoDecoder::addAudioStream(
546+
int streamIndex,
547+
const AudioStreamOptions& audioStreamOptions) {
545548
TORCH_CHECK(
546549
seekMode_ == SeekMode::approximate,
547550
"seek_mode must be 'approximate' for audio streams.");
548551

549552
addStream(streamIndex, AVMEDIA_TYPE_AUDIO);
550553

551554
auto& streamInfo = streamInfos_[activeStreamIndex_];
555+
streamInfo.audioStreamOptions = audioStreamOptions;
556+
552557
auto& streamMetadata =
553558
containerMetadata_.allStreamMetadata[activeStreamIndex_];
554559
streamMetadata.sampleRate =
@@ -1332,6 +1337,82 @@ void VideoDecoder::convertAudioAVFrameToFrameOutputOnCPU(
13321337
"pre-allocated audio tensor not supported yet.");
13331338

13341339
const AVFrame* avFrame = avFrameStream.avFrame.get();
1340+
AVFrame* output_frame = nullptr;
1341+
SwrContext* swr_ctx = NULL; // TODO RAII
1342+
1343+
const auto sampleRate =
1344+
streamInfos_[activeStreamIndex_].audioStreamOptions.sampleRate;
1345+
if (sampleRate.has_value()) {
1346+
int outRate = static_cast<int>(*sampleRate);
1347+
auto& streamMetadata =
1348+
containerMetadata_.allStreamMetadata[activeStreamIndex_];
1349+
int inRate = static_cast<int>(streamMetadata.sampleRate.value());
1350+
1351+
printf("RESAMPLEING FROM %d to %d\n", outRate, inRate);
1352+
AVSampleFormat sampleFormat = AV_SAMPLE_FMT_FLTP;
1353+
1354+
AVChannelLayout stereoLayout = AV_CHANNEL_LAYOUT_STEREO;
1355+
const AVChannelLayout* chl = &stereoLayout;
1356+
1357+
int status = swr_alloc_set_opts2(
1358+
&swr_ctx,
1359+
chl,
1360+
sampleFormat,
1361+
outRate,
1362+
chl,
1363+
sampleFormat,
1364+
inRate,
1365+
0,
1366+
NULL);
1367+
1368+
TORCH_CHECK(status == 0, "IS NULL");
1369+
1370+
if (swr_init(swr_ctx) < 0) {
1371+
swr_free(&swr_ctx);
1372+
TORCH_CHECK(false, "Failed to initialize the resampling context\n");
1373+
}
1374+
1375+
// Allocate output frame
1376+
output_frame = av_frame_alloc();
1377+
if (!output_frame) {
1378+
swr_free(&swr_ctx);
1379+
TORCH_CHECK(false, "Could not allocate output frame\n");
1380+
}
1381+
output_frame->ch_layout = stereoLayout;
1382+
output_frame->sample_rate = outRate;
1383+
output_frame->format = sampleFormat;
1384+
1385+
output_frame->nb_samples = av_rescale_rnd(
1386+
swr_get_delay(swr_ctx, inRate) + avFrame->nb_samples,
1387+
outRate,
1388+
inRate,
1389+
AV_ROUND_UP);
1390+
1391+
if (av_frame_get_buffer(output_frame, 0) < 0) {
1392+
av_frame_free(&output_frame);
1393+
swr_free(&swr_ctx);
1394+
TORCH_CHECK(false, "Could not allocate output frame samples");
1395+
}
1396+
1397+
int ret = swr_convert(
1398+
swr_ctx,
1399+
output_frame->data,
1400+
output_frame->nb_samples,
1401+
(const uint8_t**)avFrame->data,
1402+
avFrame->nb_samples);
1403+
if (ret < 0) {
1404+
av_frame_free(&output_frame);
1405+
swr_free(&swr_ctx);
1406+
TORCH_CHECK(false, "Error while converting\n");
1407+
}
1408+
1409+
printf(
1410+
"nb_samples: %d %d\n", avFrame->nb_samples, output_frame->nb_samples);
1411+
1412+
avFrame = output_frame; // lmao
1413+
} else {
1414+
printf("NO RESAMPLING\n");
1415+
}
13351416

13361417
auto numSamples = avFrame->nb_samples; // per channel
13371418
auto numChannels = getNumChannels(avFrame);
@@ -1360,6 +1441,10 @@ void VideoDecoder::convertAudioAVFrameToFrameOutputOnCPU(
13601441
av_get_sample_fmt_name(format));
13611442
}
13621443
frameOutput.data = outputData;
1444+
1445+
// TODO
1446+
av_frame_free(&output_frame);
1447+
swr_free(&swr_ctx);
13631448
}
13641449

13651450
// --------------------------------------------------------------------------

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,10 +138,19 @@ class VideoDecoder {
138138
torch::Device device = torch::kCPU;
139139
};
140140

141+
struct AudioStreamOptions {
142+
AudioStreamOptions() {}
143+
144+
// explicit AudioStreamOptions(const std::string& optionsString);
145+
std::optional<int> sampleRate;
146+
};
147+
141148
void addVideoStream(
142149
int streamIndex,
143150
const VideoStreamOptions& videoStreamOptions = VideoStreamOptions());
144-
void addAudioStream(int streamIndex);
151+
void addAudioStream(
152+
int streamIndex,
153+
const AudioStreamOptions& audioStreamOptions = AudioStreamOptions());
145154

146155
// --------------------------------------------------------------------------
147156
// DECODING AND SEEKING APIs
@@ -221,7 +230,6 @@ class VideoDecoder {
221230
double startSeconds,
222231
double stopSeconds);
223232

224-
// TODO-AUDIO: Should accept sampleRate
225233
torch::Tensor getFramesPlayedInRangeAudio(
226234
double startSeconds,
227235
std::optional<double> stopSecondsOptional = std::nullopt);
@@ -343,6 +351,7 @@ class VideoDecoder {
343351
int64_t lastDecodedAvFramePts = 0;
344352
int64_t lastDecodedAvFrameDuration = 0;
345353
VideoStreamOptions videoStreamOptions;
354+
AudioStreamOptions audioStreamOptions;
346355

347356
// color-conversion fields. Only one of FilterGraphContext and
348357
// UniqueSwsContext should be non-null.

src/torchcodec/decoders/_core/VideoDecoderOps.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ TORCH_LIBRARY(torchcodec_ns, m) {
3434
m.def(
3535
"add_video_stream(Tensor(a!) decoder, *, int? width=None, int? height=None, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str? device=None) -> ()");
3636
m.def(
37-
"add_audio_stream(Tensor(a!) decoder, *, int? stream_index=None) -> ()");
37+
"add_audio_stream(Tensor(a!) decoder, *, int? stream_index=None, int? sample_rate=None) -> ()");
3838
m.def("seek_to_pts(Tensor(a!) decoder, float seconds) -> ()");
3939
m.def("get_next_frame(Tensor(a!) decoder) -> (Tensor, Tensor, Tensor)");
4040
m.def(
@@ -213,9 +213,13 @@ void _add_video_stream(
213213

214214
void add_audio_stream(
215215
at::Tensor& decoder,
216-
std::optional<int64_t> stream_index) {
216+
std::optional<int64_t> stream_index,
217+
std::optional<int64_t> sample_rate) {
218+
VideoDecoder::AudioStreamOptions audioStreamOptions;
219+
audioStreamOptions.sampleRate = sample_rate;
220+
217221
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
218-
videoDecoder->addAudioStream(stream_index.value_or(-1));
222+
videoDecoder->addAudioStream(stream_index.value_or(-1), audioStreamOptions);
219223
}
220224

221225
void seek_to_pts(at::Tensor& decoder, double seconds) {

src/torchcodec/decoders/_core/VideoDecoderOps.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ void _add_video_stream(
5050

5151
void add_audio_stream(
5252
at::Tensor& decoder,
53-
std::optional<int64_t> stream_index = std::nullopt);
53+
std::optional<int64_t> stream_index = std::nullopt,
54+
std::optional<int64_t> sample_rate = std::nullopt);
5455

5556
// Seek to a particular presentation timestamp in the video in seconds.
5657
void seek_to_pts(at::Tensor& decoder, double seconds);

0 commit comments

Comments
 (0)