
Commit 6be8abc

Merge branch 'main' of github.com:pytorch/torchcodec into avcodecptr

2 parents 7b3be11 + b7f8e0c

File tree

4 files changed: +94, -127 lines

.github/workflows/build_ffmpeg.yaml
src/torchcodec/decoders/_core/VideoDecoder.cpp
src/torchcodec/decoders/_core/VideoDecoder.h
test/decoders/VideoDecoderTest.cpp

.github/workflows/build_ffmpeg.yaml

Lines changed: 3 additions & 0 deletions
@@ -29,6 +29,9 @@ jobs:
       matrix:
         ffmpeg-version: ["4.4.4", "5.1.4", "6.1.1", "7.0.1"]
     uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+    permissions:
+      id-token: write
+      contents: read
     with:
       job-name: Build
       upload-artifact: ffmpeg-lgpl

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 89 additions & 97 deletions
@@ -563,13 +563,14 @@ void VideoDecoder::scanFileAndUpdateMetadataAndIndex() {
     if (packet->flags & AV_PKT_FLAG_DISCARD) {
       continue;
     }
-    auto& stream = containerMetadata_.streams[streamIndex];
-    stream.minPtsFromScan =
-        std::min(stream.minPtsFromScan.value_or(INT64_MAX), packet->pts);
-    stream.maxPtsFromScan = std::max(
-        stream.maxPtsFromScan.value_or(INT64_MIN),
+    auto& streamMetadata = containerMetadata_.streams[streamIndex];
+    streamMetadata.minPtsFromScan = std::min(
+        streamMetadata.minPtsFromScan.value_or(INT64_MAX), packet->pts);
+    streamMetadata.maxPtsFromScan = std::max(
+        streamMetadata.maxPtsFromScan.value_or(INT64_MIN),
         packet->pts + packet->duration);
-    stream.numFramesFromScan = stream.numFramesFromScan.value_or(0) + 1;
+    streamMetadata.numFramesFromScan =
+        streamMetadata.numFramesFromScan.value_or(0) + 1;

     FrameInfo frameInfo;
     frameInfo.pts = packet->pts;
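For reference, a minimal standalone sketch of the scan pattern in this hunk (the StreamScanStats type is made up for illustration, not part of the patch): value_or() supplies the identity element on the first packet, so no separate first-packet branch is needed, and the max is bounded by the packet's end time (pts + duration), not its start.

#include <algorithm>
#include <cstdint>
#include <optional>

struct StreamScanStats {
  std::optional<int64_t> minPts;
  std::optional<int64_t> maxPts;
  std::optional<int64_t> numFrames;

  void addPacket(int64_t pts, int64_t duration) {
    // On the first packet, value_or() yields INT64_MAX/INT64_MIN, so the
    // packet's own timestamps win.
    minPts = std::min(minPts.value_or(INT64_MAX), pts);
    // The end of the packet, not its start, bounds the stream's max pts.
    maxPts = std::max(maxPts.value_or(INT64_MIN), pts + duration);
    numFrames = numFrames.value_or(0) + 1;
  }
};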
@@ -579,16 +580,17 @@ void VideoDecoder::scanFileAndUpdateMetadataAndIndex() {
     }
     streams_[streamIndex].allFrames.push_back(frameInfo);
   }
-  for (size_t i = 0; i < containerMetadata_.streams.size(); ++i) {
-    auto& streamMetadata = containerMetadata_.streams[i];
-    auto stream = formatContext_->streams[i];
+  for (size_t streamIndex = 0; streamIndex < containerMetadata_.streams.size();
+       ++streamIndex) {
+    auto& streamMetadata = containerMetadata_.streams[streamIndex];
+    auto avStream = formatContext_->streams[streamIndex];
     if (streamMetadata.minPtsFromScan.has_value()) {
       streamMetadata.minPtsSecondsFromScan =
-          *streamMetadata.minPtsFromScan * av_q2d(stream->time_base);
+          *streamMetadata.minPtsFromScan * av_q2d(avStream->time_base);
     }
     if (streamMetadata.maxPtsFromScan.has_value()) {
       streamMetadata.maxPtsSecondsFromScan =
-          *streamMetadata.maxPtsFromScan * av_q2d(stream->time_base);
+          *streamMetadata.maxPtsFromScan * av_q2d(avStream->time_base);
     }
   }
   int ffmepgStatus =
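A minimal sketch of the pts-to-seconds conversion this hunk relies on, equivalent to what the ptsToSeconds helper used later in this diff does: an AVStream's time_base is an AVRational (e.g. 1/30000), and FFmpeg's av_q2d() turns it into a double, so a pts in time-base units maps to pts * num / den seconds.

extern "C" {
#include <libavutil/rational.h>
}
#include <cstdint>

// Convert a timestamp in time-base units to seconds.
double toSeconds(int64_t pts, AVRational timeBase) {
  return pts * av_q2d(timeBase); // e.g. 1001 * (1/30000) is about 0.0334 s
}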
@@ -598,23 +600,23 @@ void VideoDecoder::scanFileAndUpdateMetadataAndIndex() {
         "Could not seek file to pts=0: " +
         getFFMPEGErrorStringFromErrorCode(ffmepgStatus));
   }
-  for (auto& [streamIndex, stream] : streams_) {
+  for (auto& [streamIndex, streamInfo] : streams_) {
     std::sort(
-        stream.keyFrames.begin(),
-        stream.keyFrames.end(),
+        streamInfo.keyFrames.begin(),
+        streamInfo.keyFrames.end(),
         [](const FrameInfo& frameInfo1, const FrameInfo& frameInfo2) {
           return frameInfo1.pts < frameInfo2.pts;
         });
     std::sort(
-        stream.allFrames.begin(),
-        stream.allFrames.end(),
+        streamInfo.allFrames.begin(),
+        streamInfo.allFrames.end(),
         [](const FrameInfo& frameInfo1, const FrameInfo& frameInfo2) {
           return frameInfo1.pts < frameInfo2.pts;
         });

-    for (size_t i = 0; i < stream.allFrames.size(); ++i) {
-      if (i + 1 < stream.allFrames.size()) {
-        stream.allFrames[i].nextPts = stream.allFrames[i + 1].pts;
+    for (size_t i = 0; i < streamInfo.allFrames.size(); ++i) {
+      if (i + 1 < streamInfo.allFrames.size()) {
+        streamInfo.allFrames[i].nextPts = streamInfo.allFrames[i + 1].pts;
       }
     }
   }
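As a standalone illustration (the Frame type and sortAndLinkFrames are hypothetical names, not the patch's), the post-scan step in this hunk reduces to: sort frames by pts, then link each frame to its successor so a frame's span can be treated as [pts, nextPts).

#include <algorithm>
#include <cstdint>
#include <vector>

struct Frame {
  int64_t pts = 0;
  int64_t nextPts = INT64_MAX; // the last frame has no successor
};

void sortAndLinkFrames(std::vector<Frame>& frames) {
  std::sort(frames.begin(), frames.end(), [](const Frame& a, const Frame& b) {
    return a.pts < b.pts;
  });
  for (size_t i = 0; i + 1 < frames.size(); ++i) {
    frames[i].nextPts = frames[i + 1].pts;
  }
}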
@@ -870,11 +872,9 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(
   AVFrame* frame = rawOutput.frame.get();
   output.streamIndex = streamIndex;
   auto& streamInfo = streams_[streamIndex];
-  output.streamType = streams_[streamIndex].stream->codecpar->codec_type;
-  output.pts = frame->pts;
+  TORCH_CHECK(streamInfo.stream->codecpar->codec_type == AVMEDIA_TYPE_VIDEO);
   output.ptsSeconds =
       ptsToSeconds(frame->pts, formatContext_->streams[streamIndex]->time_base);
-  output.duration = getDuration(frame);
   output.durationSeconds = ptsToSeconds(
       getDuration(frame), formatContext_->streams[streamIndex]->time_base);
   // TODO: we should fold preAllocatedOutputTensor into RawDecodedOutput.
@@ -931,86 +931,78 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
   }

   torch::Tensor outputTensor;
-  if (output.streamType == AVMEDIA_TYPE_VIDEO) {
-    // We need to compare the current frame context with our previous frame
-    // context. If they are different, then we need to re-create our colorspace
-    // conversion objects. We create our colorspace conversion objects late so
-    // that we don't have to depend on the unreliable metadata in the header.
-    // And we sometimes re-create them because it's possible for frame
-    // resolution to change mid-stream. Finally, we want to reuse the colorspace
-    // conversion objects as much as possible for performance reasons.
-    enum AVPixelFormat frameFormat =
-        static_cast<enum AVPixelFormat>(frame->format);
-    auto frameContext = DecodedFrameContext{
-        frame->width,
-        frame->height,
-        frameFormat,
-        expectedOutputWidth,
-        expectedOutputHeight};
+  // We need to compare the current frame context with our previous frame
+  // context. If they are different, then we need to re-create our colorspace
+  // conversion objects. We create our colorspace conversion objects late so
+  // that we don't have to depend on the unreliable metadata in the header.
+  // And we sometimes re-create them because it's possible for frame
+  // resolution to change mid-stream. Finally, we want to reuse the colorspace
+  // conversion objects as much as possible for performance reasons.
+  enum AVPixelFormat frameFormat =
+      static_cast<enum AVPixelFormat>(frame->format);
+  auto frameContext = DecodedFrameContext{
+      frame->width,
+      frame->height,
+      frameFormat,
+      expectedOutputWidth,
+      expectedOutputHeight};

-    if (streamInfo.colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
-      outputTensor = preAllocatedOutputTensor.value_or(allocateEmptyHWCTensor(
-          expectedOutputHeight, expectedOutputWidth, torch::kCPU));
+  if (streamInfo.colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
+    outputTensor = preAllocatedOutputTensor.value_or(allocateEmptyHWCTensor(
+        expectedOutputHeight, expectedOutputWidth, torch::kCPU));

-      if (!streamInfo.swsContext ||
-          streamInfo.prevFrameContext != frameContext) {
-        createSwsContext(streamInfo, frameContext, frame->colorspace);
-        streamInfo.prevFrameContext = frameContext;
-      }
-      int resultHeight =
-          convertFrameToTensorUsingSwsScale(streamIndex, frame, outputTensor);
-      // If this check failed, it would mean that the frame wasn't reshaped to
-      // the expected height.
-      // TODO: Can we do the same check for width?
-      TORCH_CHECK(
-          resultHeight == expectedOutputHeight,
-          "resultHeight != expectedOutputHeight: ",
-          resultHeight,
-          " != ",
-          expectedOutputHeight);
+    if (!streamInfo.swsContext || streamInfo.prevFrameContext != frameContext) {
+      createSwsContext(streamInfo, frameContext, frame->colorspace);
+      streamInfo.prevFrameContext = frameContext;
+    }
+    int resultHeight =
+        convertFrameToTensorUsingSwsScale(streamIndex, frame, outputTensor);
+    // If this check failed, it would mean that the frame wasn't reshaped to
+    // the expected height.
+    // TODO: Can we do the same check for width?
+    TORCH_CHECK(
+        resultHeight == expectedOutputHeight,
+        "resultHeight != expectedOutputHeight: ",
+        resultHeight,
+        " != ",
+        expectedOutputHeight);
+
+    output.frame = outputTensor;
+  } else if (
+      streamInfo.colorConversionLibrary ==
+      ColorConversionLibrary::FILTERGRAPH) {
+    if (!streamInfo.filterState.filterGraph ||
+        streamInfo.prevFrameContext != frameContext) {
+      createFilterGraph(streamInfo, expectedOutputHeight, expectedOutputWidth);
+      streamInfo.prevFrameContext = frameContext;
+    }
+    outputTensor = convertFrameToTensorUsingFilterGraph(streamIndex, frame);

-      output.frame = outputTensor;
-    } else if (
-        streamInfo.colorConversionLibrary ==
-        ColorConversionLibrary::FILTERGRAPH) {
-      if (!streamInfo.filterState.filterGraph ||
-          streamInfo.prevFrameContext != frameContext) {
-        createFilterGraph(
-            streamInfo, expectedOutputHeight, expectedOutputWidth);
-        streamInfo.prevFrameContext = frameContext;
-      }
-      outputTensor = convertFrameToTensorUsingFilterGraph(streamIndex, frame);
-
-      // Similarly to above, if this check fails it means the frame wasn't
-      // reshaped to its expected dimensions by filtergraph.
-      auto shape = outputTensor.sizes();
-      TORCH_CHECK(
-          (shape.size() == 3) && (shape[0] == expectedOutputHeight) &&
-              (shape[1] == expectedOutputWidth) && (shape[2] == 3),
-          "Expected output tensor of shape ",
-          expectedOutputHeight,
-          "x",
-          expectedOutputWidth,
-          "x3, got ",
-          shape);
-
-      if (preAllocatedOutputTensor.has_value()) {
-        // We have already validated that preAllocatedOutputTensor and
-        // outputTensor have the same shape.
-        preAllocatedOutputTensor.value().copy_(outputTensor);
-        output.frame = preAllocatedOutputTensor.value();
-      } else {
-        output.frame = outputTensor;
-      }
+    // Similarly to above, if this check fails it means the frame wasn't
+    // reshaped to its expected dimensions by filtergraph.
+    auto shape = outputTensor.sizes();
+    TORCH_CHECK(
+        (shape.size() == 3) && (shape[0] == expectedOutputHeight) &&
+            (shape[1] == expectedOutputWidth) && (shape[2] == 3),
+        "Expected output tensor of shape ",
+        expectedOutputHeight,
+        "x",
+        expectedOutputWidth,
+        "x3, got ",
+        shape);
+
+    if (preAllocatedOutputTensor.has_value()) {
+      // We have already validated that preAllocatedOutputTensor and
+      // outputTensor have the same shape.
+      preAllocatedOutputTensor.value().copy_(outputTensor);
+      output.frame = preAllocatedOutputTensor.value();
     } else {
-      throw std::runtime_error(
-          "Invalid color conversion library: " +
-          std::to_string(static_cast<int>(streamInfo.colorConversionLibrary)));
+      output.frame = outputTensor;
     }
-  } else if (output.streamType == AVMEDIA_TYPE_AUDIO) {
-    // TODO: https://github.com/pytorch-labs/torchcodec/issues/85 implement
-    // audio decoding.
-    throw std::runtime_error("Audio is not supported yet.");
+  } else {
+    throw std::runtime_error(
+        "Invalid color conversion library: " +
+        std::to_string(static_cast<int>(streamInfo.colorConversionLibrary)));
   }
 }
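The comment block retained in this hunk describes a lazy-rebuild cache. A minimal sketch of that pattern with hypothetical types (the real code keys on DecodedFrameContext and caches an SwsContext or filtergraph per stream):

#include <memory>

struct FrameContext {
  int width = 0;
  int height = 0;
  int pixelFormat = 0;

  bool operator!=(const FrameContext& other) const {
    return width != other.width || height != other.height ||
        pixelFormat != other.pixelFormat;
  }
};

struct Converter {}; // stand-in for an SwsContext or filtergraph

struct ConversionCache {
  std::unique_ptr<Converter> converter;
  FrameContext prevContext;

  // Rebuild only when there is no converter yet, or when the frame's
  // geometry or pixel format changed mid-stream; otherwise reuse the
  // cached object for performance.
  Converter& get(const FrameContext& current) {
    if (!converter || prevContext != current) {
      converter = std::make_unique<Converter>();
      prevContext = current;
    }
    return *converter;
  }
};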

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 2 additions & 28 deletions
@@ -160,48 +160,22 @@ class VideoDecoder {
   // Calling getNextFrameOutputNoDemuxInternal() will return the first frame at
   // or after this position.
   void setCursorPtsInSeconds(double seconds);
-  // This is an internal structure that is used to store the decoded output
-  // from decoding a frame through color conversion. Example usage is:
-  //
-  //   RawDecodedOutput rawOutput = getDecodedOutputWithFilter();
-  //   // Now allocate a single tensor or a batch tensor.
-  //   torch::Tensor userOutput = torch::empty(...);
-  //   // Now fill in `data` and `size`.
-  //   rawOutput.data = userOutput.data_ptr();
-  //   // Now run the color conversion.
-  //   convertFrameToBufferUsingSwsScale(rawOutput);
-  //
-  // This structure ensures we always keep the streamIndex and frame together
-  // with the data output. Note that AVFrame itself doesn't retain the
-  // streamIndex.
+  // This structure ensures we always keep the streamIndex and AVFrame together
+  // Note that AVFrame itself doesn't retain the streamIndex.
   struct RawDecodedOutput {
     // The actual decoded output as a unique pointer to an AVFrame.
     UniqueAVFrame frame;
     // The stream index of the decoded frame.
     int streamIndex;
-    // This is an unowned pointer that we copy the frame data to after color
-    // conversion.
-    // For a single tensor this points to the start of data_ptr. For a batch
-    // tensor it may point to the middle of the allocated batch tensor.
-    void* data = nullptr;
-    // We carry around the size to ensure we don't stomp on memory while doing
-    // color conversion.
-    size_t size = 0;
   };
   struct DecodedOutput {
     // The actual decoded output as a Tensor.
     torch::Tensor frame;
-    // Could be AVMEDIA_TYPE_VIDEO or AVMEDIA_TYPE_AUDIO.
-    AVMediaType streamType;
     // The stream index of the decoded frame. Used to distinguish
     // between streams that are of the same type.
     int streamIndex;
-    // The presentation timestamp of the decoded frame in time base.
-    int64_t pts;
     // The presentation timestamp of the decoded frame in seconds.
     double ptsSeconds;
-    // The duration of the decoded frame in time base.
-    int64_t duration;
     // The duration of the decoded frame in seconds.
     double durationSeconds;
   };
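A hypothetical consumer of the slimmed-down struct, to show the shape of the API after this change: timestamps are exposed in seconds only, since the raw pts and duration fields in time-base units are gone. printFrameTiming is a made-up helper, not part of the patch, and assumes VideoDecoder.h is on the include path.

#include <cstdio>

void printFrameTiming(const VideoDecoder::DecodedOutput& output) {
  std::printf(
      "stream %d: pts=%.4f s, duration=%.4f s\n",
      output.streamIndex,
      output.ptsSeconds,
      output.durationSeconds);
}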

test/decoders/VideoDecoderTest.cpp

Lines changed: 0 additions & 2 deletions
@@ -172,12 +172,10 @@ TEST_P(VideoDecoderTest, ReturnsFirstTwoFramesOfVideo) {
   torch::Tensor tensor0FromOurDecoder = output.frame;
   EXPECT_EQ(tensor0FromOurDecoder.sizes(), std::vector<long>({3, 270, 480}));
   EXPECT_EQ(output.ptsSeconds, 0.0);
-  EXPECT_EQ(output.pts, 0);
   output = ourDecoder->getNextFrameNoDemux();
   torch::Tensor tensor1FromOurDecoder = output.frame;
   EXPECT_EQ(tensor1FromOurDecoder.sizes(), std::vector<long>({3, 270, 480}));
   EXPECT_EQ(output.ptsSeconds, 1'001. / 30'000);
-  EXPECT_EQ(output.pts, 1001);

   torch::Tensor tensor0FromFFMPEG =
       readTensorFromDisk("nasa_13013.mp4.stream3.frame000000.pt");
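For context on the expected value above, a small check of the arithmetic, assuming the NTSC-style timing implied by the deleted pts expectations (pts 0 and 1001 in a 1/30000 time base, i.e. roughly 29.97 fps):

#include <cassert>

int main() {
  // Frame 1 of a 30000/1001 fps stream starts 1001/30000 seconds in.
  double frame1PtsSeconds = 1'001. / 30'000;
  assert(frame1PtsSeconds > 0.0333 && frame1PtsSeconds < 0.0334);
  return 0;
}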
