Skip to content

Commit 1b362ec

Browse files
committed
Add basic decoding
1 parent ace0bd4 commit 1b362ec

File tree

4 files changed

+54
-11
lines changed

4 files changed

+54
-11
lines changed

src/torchcodec/decoders/_core/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@ set(CMAKE_CXX_STANDARD 17)
44
set(CMAKE_CXX_STANDARD_REQUIRED ON)
55

66
find_package(Torch REQUIRED)
7-
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -pedantic -Werror ${TORCH_CXX_FLAGS}")
7+
# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -pedantic -Werror ${TORCH_CXX_FLAGS}")
8+
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra ${TORCH_CXX_FLAGS}")
89
find_package(Python3 ${PYTHON_VERSION} EXACT COMPONENTS Development)
910

1011
function(make_torchcodec_library library_name ffmpeg_target)

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 48 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -425,6 +425,9 @@ void VideoDecoder::addStream(
425425
TORCH_CHECK(
426426
activeStreamIndex_ == NO_ACTIVE_STREAM,
427427
"Can only add one single stream.");
428+
TORCH_CHECK(
429+
mediaType == AVMEDIA_TYPE_VIDEO || mediaType == AVMEDIA_TYPE_AUDIO,
430+
"Can only add video or audio streams.");
428431
TORCH_CHECK(formatContext_.get() != nullptr);
429432

430433
AVCodecOnlyUseForCallingAVFindBestStream avCodec = nullptr;
@@ -448,9 +451,10 @@ void VideoDecoder::addStream(
448451

449452
// This should never happen, checking just to be safe.
450453
TORCH_CHECK(
451-
streamInfo.stream->codecpar->codec_type == mediaType,
452-
"FFmpeg found stream with index ", activeStreamIndex_, " which is of the wrong media type.");
453-
454+
streamInfo.stream->codecpar->codec_type == mediaType,
455+
"FFmpeg found stream with index ",
456+
activeStreamIndex_,
457+
" which is of the wrong media type.");
454458

455459
if (mediaType == AVMEDIA_TYPE_VIDEO &&
456460
videoStreamOptions.device.type() == torch::kCUDA) {
@@ -1076,8 +1080,10 @@ VideoDecoder::FrameOutput VideoDecoder::convertAVFrameToFrameOutput(
10761080
avFrame->pts, formatContext_->streams[streamIndex]->time_base);
10771081
frameOutput.durationSeconds = ptsToSeconds(
10781082
getDuration(avFrame), formatContext_->streams[streamIndex]->time_base);
1079-
// TODO: we should fold preAllocatedOutputTensor into AVFrameStream.
1080-
if (streamInfo.videoStreamOptions.device.type() == torch::kCPU) {
1083+
if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) {
1084+
// TODO: handle preAllocatedOutputTensor for audio
1085+
convertAudioAVFrameToFrameOutputOnCPU(avFrameStream, frameOutput);
1086+
} else if (streamInfo.videoStreamOptions.device.type() == torch::kCPU) {
10811087
convertAVFrameToFrameOutputOnCPU(
10821088
avFrameStream, frameOutput, preAllocatedOutputTensor);
10831089
} else if (streamInfo.videoStreamOptions.device.type() == torch::kCUDA) {
@@ -1253,6 +1259,39 @@ torch::Tensor VideoDecoder::convertAVFrameToTensorUsingFilterGraph(
12531259
filteredAVFramePtr->data[0], shape, strides, deleter, {torch::kUInt8});
12541260
}
12551261

1262+
// Converts a decoded audio frame into a CPU tensor and stores it in
// frameOutput.data.
//
// The output tensor is float32 with shape (numChannels, numSamples), where
// numSamples is the per-channel sample count of the frame. Each channel plane
// of the AVFrame is copied into the matching row of the tensor.
//
// Only AV_SAMPLE_FMT_FLTP is supported so far; any other sample format fails
// a TORCH_CHECK. No pre-allocated output tensor is accepted for audio.
void VideoDecoder::convertAudioAVFrameToFrameOutputOnCPU(
    VideoDecoder::AVFrameStream& avFrameStream,
    FrameOutput& frameOutput) {
  const AVFrame* avFrame = avFrameStream.avFrame.get();

  const auto numSamples = avFrame->nb_samples; // per channel
  const auto numChannels =
      avFrame->ch_layout.nb_channels; // TODO handle other ffmpeg versions

  // TODO: dtype should be format-dependent
  torch::Tensor data = torch::empty({numChannels, numSamples}, torch::kFloat32);

  const AVSampleFormat format = static_cast<AVSampleFormat>(avFrame->format);
  // TODO Implement all formats
  switch (format) {
    case AV_SAMPLE_FMT_FLTP: {
      // Every channel plane holds the same number of bytes, so compute the
      // size once instead of on every loop iteration.
      const auto numBytesPerChannel =
          numSamples * av_get_bytes_per_sample(format);
      uint8_t* pData = static_cast<uint8_t*>(data.data_ptr());
      for (auto channel = 0; channel < numChannels; ++channel) {
        memcpy(pData, avFrame->extended_data[channel], numBytesPerChannel);
        pData += numBytesPerChannel;
      }
      break;
    }
    default: {
      // av_get_sample_fmt_name() returns NULL for unrecognized formats;
      // streaming a null char pointer into the error message is undefined
      // behavior, so substitute a placeholder name.
      const char* formatName = av_get_sample_fmt_name(format);
      TORCH_CHECK(
          false,
          "Unsupported audio format (yet!): ",
          formatName != nullptr ? formatName : "unknown");
    }
  }
  frameOutput.data = data;
}
1294+
12561295
// --------------------------------------------------------------------------
12571296
// OUTPUT ALLOCATION AND SHAPE CONVERSION
12581297
// --------------------------------------------------------------------------
@@ -1298,6 +1337,10 @@ torch::Tensor allocateEmptyHWCTensor(
12981337
// Calling permute() is guaranteed to return a view as per the docs:
12991338
// https://pytorch.org/docs/stable/generated/torch.permute.html
13001339
torch::Tensor VideoDecoder::maybePermuteHWC2CHW(torch::Tensor& hwcTensor) {
1340+
if (streamInfos_[activeStreamIndex_].avMediaType == AVMEDIA_TYPE_AUDIO) {
1341+
// TODO: Is this really how we want to handle audio?
1342+
return hwcTensor;
1343+
}
13011344
if (streamInfos_[activeStreamIndex_].videoStreamOptions.dimensionOrder ==
13021345
"NHWC") {
13031346
return hwcTensor;

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,10 @@ class VideoDecoder {
381381
FrameOutput& frameOutput,
382382
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
383383

384+
void convertAudioAVFrameToFrameOutputOnCPU(
385+
AVFrameStream& avFrameStream,
386+
FrameOutput& frameOutput);
387+
384388
torch::Tensor convertAVFrameToTensorUsingFilterGraph(const AVFrame* avFrame);
385389

386390
int convertAVFrameToTensorUsingSwsScale(

src/torchcodec/decoders/_core/VideoDecoderOps.cpp

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -245,11 +245,6 @@ OpsFrameOutput get_next_frame(at::Tensor& decoder) {
245245
} catch (const VideoDecoder::EndOfFileException& e) {
246246
C10_THROW_ERROR(IndexError, e.what());
247247
}
248-
if (result.data.sizes().size() != 3) {
249-
throw std::runtime_error(
250-
"image_size is unexpected. Expected 3, got: " +
251-
std::to_string(result.data.sizes().size()));
252-
}
253248
return makeOpsFrameOutput(result);
254249
}
255250

0 commit comments

Comments
 (0)