
Commit 99b0d4f

Merge branch 'main' of github.com:pytorch/torchcodec into approx

2 parents: 737e1b6 + b7f8e0c
File tree: 7 files changed (+106 −135 lines)


.github/workflows/build_ffmpeg.yaml

Lines changed: 3 additions & 0 deletions

@@ -29,6 +29,9 @@ jobs:
       matrix:
         ffmpeg-version: ["4.4.4", "5.1.4", "6.1.1", "7.0.1"]
     uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+    permissions:
+      id-token: write
+      contents: read
     with:
       job-name: Build
       upload-artifact: ffmpeg-lgpl

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 85 additions & 92 deletions

@@ -600,6 +600,8 @@ void VideoDecoder::scanFileAndUpdateMetadataAndIndex() {
     streamMetadata.maxPtsFromScan = std::max(
         streamMetadata.maxPtsFromScan.value_or(INT64_MIN),
         packet->pts + packet->duration);
+    streamMetadata.numFramesFromScan =
+        streamMetadata.numFramesFromScan.value_or(0) + 1;
 
     // Note that we set the other value in this struct, nextPts, only after
     // we have scanned all packets and sorted by pts.
@@ -612,19 +614,20 @@ void VideoDecoder::scanFileAndUpdateMetadataAndIndex() {
 
   // Set all per-stream metadata that requires knowing the content of all
   // packets.
-  for (size_t i = 0; i < containerMetadata_.streams.size(); ++i) {
-    auto& streamMetadata = containerMetadata_.streams[i];
-    auto stream = formatContext_->streams[i];
+  for (size_t streamIndex = 0; streamIndex < containerMetadata_.streams.size();
+       ++streamIndex) {
+    auto& streamMetadata = containerMetadata_.streams[streamIndex];
+    auto avStream = formatContext_->streams[streamIndex];
 
-    streamMetadata.numFramesFromScan = streams_[i].allFrames.size();
+    streamMetadata.numFramesFromScan = streams_[streamIndex].allFrames.size();
 
     if (streamMetadata.minPtsFromScan.has_value()) {
      streamMetadata.minPtsSecondsFromScan =
-          *streamMetadata.minPtsFromScan * av_q2d(stream->time_base);
+          *streamMetadata.minPtsFromScan * av_q2d(avStream->time_base);
     }
     if (streamMetadata.maxPtsFromScan.has_value()) {
       streamMetadata.maxPtsSecondsFromScan =
-          *streamMetadata.maxPtsFromScan * av_q2d(stream->time_base);
+          *streamMetadata.maxPtsFromScan * av_q2d(avStream->time_base);
     }
   }
 
@@ -638,23 +641,23 @@ void VideoDecoder::scanFileAndUpdateMetadataAndIndex() {
   }
 
   // Sort all frames by their pts.
-  for (auto& [streamIndex, stream] : streams_) {
+  for (auto& [streamIndex, streamInfo] : streams_) {
     std::sort(
-        stream.keyFrames.begin(),
-        stream.keyFrames.end(),
+        streamInfo.keyFrames.begin(),
+        streamInfo.keyFrames.end(),
         [](const FrameInfo& frameInfo1, const FrameInfo& frameInfo2) {
           return frameInfo1.pts < frameInfo2.pts;
         });
     std::sort(
-        stream.allFrames.begin(),
-        stream.allFrames.end(),
+        streamInfo.allFrames.begin(),
+        streamInfo.allFrames.end(),
         [](const FrameInfo& frameInfo1, const FrameInfo& frameInfo2) {
           return frameInfo1.pts < frameInfo2.pts;
         });
 
-    for (size_t i = 0; i < stream.allFrames.size(); ++i) {
-      if (i + 1 < stream.allFrames.size()) {
-        stream.allFrames[i].nextPts = stream.allFrames[i + 1].pts;
+    for (size_t i = 0; i < streamInfo.allFrames.size(); ++i) {
+      if (i + 1 < streamInfo.allFrames.size()) {
+        streamInfo.allFrames[i].nextPts = streamInfo.allFrames[i + 1].pts;
       }
     }
   }
@@ -911,11 +914,9 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(
   AVFrame* frame = rawOutput.frame.get();
   output.streamIndex = streamIndex;
   auto& streamInfo = streams_[streamIndex];
-  output.streamType = streams_[streamIndex].stream->codecpar->codec_type;
-  output.pts = frame->pts;
+  TORCH_CHECK(streamInfo.stream->codecpar->codec_type == AVMEDIA_TYPE_VIDEO);
   output.ptsSeconds =
       ptsToSeconds(frame->pts, formatContext_->streams[streamIndex]->time_base);
-  output.duration = getDuration(frame);
   output.durationSeconds = ptsToSeconds(
       getDuration(frame), formatContext_->streams[streamIndex]->time_base);
   // TODO: we should fold preAllocatedOutputTensor into RawDecodedOutput.
@@ -972,86 +973,78 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
   }
 
   torch::Tensor outputTensor;
-  if (output.streamType == AVMEDIA_TYPE_VIDEO) {
-    // We need to compare the current frame context with our previous frame
-    // context. If they are different, then we need to re-create our colorspace
-    // conversion objects. We create our colorspace conversion objects late so
-    // that we don't have to depend on the unreliable metadata in the header.
-    // And we sometimes re-create them because it's possible for frame
-    // resolution to change mid-stream. Finally, we want to reuse the colorspace
-    // conversion objects as much as possible for performance reasons.
-    enum AVPixelFormat frameFormat =
-        static_cast<enum AVPixelFormat>(frame->format);
-    auto frameContext = DecodedFrameContext{
-        frame->width,
-        frame->height,
-        frameFormat,
-        expectedOutputWidth,
-        expectedOutputHeight};
+  // We need to compare the current frame context with our previous frame
+  // context. If they are different, then we need to re-create our colorspace
+  // conversion objects. We create our colorspace conversion objects late so
+  // that we don't have to depend on the unreliable metadata in the header.
+  // And we sometimes re-create them because it's possible for frame
+  // resolution to change mid-stream. Finally, we want to reuse the colorspace
+  // conversion objects as much as possible for performance reasons.
+  enum AVPixelFormat frameFormat =
+      static_cast<enum AVPixelFormat>(frame->format);
+  auto frameContext = DecodedFrameContext{
+      frame->width,
+      frame->height,
+      frameFormat,
+      expectedOutputWidth,
+      expectedOutputHeight};
 
-  if (streamInfo.colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
-    outputTensor = preAllocatedOutputTensor.value_or(allocateEmptyHWCTensor(
-        expectedOutputHeight, expectedOutputWidth, torch::kCPU));
+  if (streamInfo.colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
+    outputTensor = preAllocatedOutputTensor.value_or(allocateEmptyHWCTensor(
+        expectedOutputHeight, expectedOutputWidth, torch::kCPU));
 
-      if (!streamInfo.swsContext ||
-          streamInfo.prevFrameContext != frameContext) {
-        createSwsContext(streamInfo, frameContext, frame->colorspace);
-        streamInfo.prevFrameContext = frameContext;
-      }
-      int resultHeight =
-          convertFrameToTensorUsingSwsScale(streamIndex, frame, outputTensor);
-      // If this check failed, it would mean that the frame wasn't reshaped to
-      // the expected height.
-      // TODO: Can we do the same check for width?
-      TORCH_CHECK(
-          resultHeight == expectedOutputHeight,
-          "resultHeight != expectedOutputHeight: ",
-          resultHeight,
-          " != ",
-          expectedOutputHeight);
+    if (!streamInfo.swsContext || streamInfo.prevFrameContext != frameContext) {
+      createSwsContext(streamInfo, frameContext, frame->colorspace);
+      streamInfo.prevFrameContext = frameContext;
+    }
+    int resultHeight =
+        convertFrameToTensorUsingSwsScale(streamIndex, frame, outputTensor);
+    // If this check failed, it would mean that the frame wasn't reshaped to
+    // the expected height.
+    // TODO: Can we do the same check for width?
+    TORCH_CHECK(
+        resultHeight == expectedOutputHeight,
+        "resultHeight != expectedOutputHeight: ",
+        resultHeight,
+        " != ",
+        expectedOutputHeight);
+
+    output.frame = outputTensor;
+  } else if (
+      streamInfo.colorConversionLibrary ==
+      ColorConversionLibrary::FILTERGRAPH) {
+    if (!streamInfo.filterState.filterGraph ||
+        streamInfo.prevFrameContext != frameContext) {
+      createFilterGraph(streamInfo, expectedOutputHeight, expectedOutputWidth);
+      streamInfo.prevFrameContext = frameContext;
+    }
+    outputTensor = convertFrameToTensorUsingFilterGraph(streamIndex, frame);
 
-      output.frame = outputTensor;
-    } else if (
-        streamInfo.colorConversionLibrary ==
-        ColorConversionLibrary::FILTERGRAPH) {
-      if (!streamInfo.filterState.filterGraph ||
-          streamInfo.prevFrameContext != frameContext) {
-        createFilterGraph(
-            streamInfo, expectedOutputHeight, expectedOutputWidth);
-        streamInfo.prevFrameContext = frameContext;
-      }
-      outputTensor = convertFrameToTensorUsingFilterGraph(streamIndex, frame);
-
-      // Similarly to above, if this check fails it means the frame wasn't
-      // reshaped to its expected dimensions by filtergraph.
-      auto shape = outputTensor.sizes();
-      TORCH_CHECK(
-          (shape.size() == 3) && (shape[0] == expectedOutputHeight) &&
-              (shape[1] == expectedOutputWidth) && (shape[2] == 3),
-          "Expected output tensor of shape ",
-          expectedOutputHeight,
-          "x",
-          expectedOutputWidth,
-          "x3, got ",
-          shape);
-
-      if (preAllocatedOutputTensor.has_value()) {
-        // We have already validated that preAllocatedOutputTensor and
-        // outputTensor have the same shape.
-        preAllocatedOutputTensor.value().copy_(outputTensor);
-        output.frame = preAllocatedOutputTensor.value();
-      } else {
-        output.frame = outputTensor;
-      }
+    // Similarly to above, if this check fails it means the frame wasn't
+    // reshaped to its expected dimensions by filtergraph.
+    auto shape = outputTensor.sizes();
+    TORCH_CHECK(
+        (shape.size() == 3) && (shape[0] == expectedOutputHeight) &&
+            (shape[1] == expectedOutputWidth) && (shape[2] == 3),
+        "Expected output tensor of shape ",
+        expectedOutputHeight,
+        "x",
+        expectedOutputWidth,
+        "x3, got ",
+        shape);
+
+    if (preAllocatedOutputTensor.has_value()) {
+      // We have already validated that preAllocatedOutputTensor and
+      // outputTensor have the same shape.
+      preAllocatedOutputTensor.value().copy_(outputTensor);
+      output.frame = preAllocatedOutputTensor.value();
    } else {
-      throw std::runtime_error(
-          "Invalid color conversion library: " +
-          std::to_string(static_cast<int>(streamInfo.colorConversionLibrary)));
+      output.frame = outputTensor;
    }
-  } else if (output.streamType == AVMEDIA_TYPE_AUDIO) {
-    // TODO: https://github.com/pytorch-labs/torchcodec/issues/85 implement
-    // audio decoding.
-    throw std::runtime_error("Audio is not supported yet.");
+  } else {
+    throw std::runtime_error(
+        "Invalid color conversion library: " +
+        std::to_string(static_cast<int>(streamInfo.colorConversionLibrary)));
   }
 }
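A note on the hunk above: the retained comment block describes a lazy cache-and-recreate pattern, where the colorspace conversion object is built on first use and rebuilt only when the per-frame context changes. Below is a minimal Python sketch of that pattern; it assumes nothing about torchcodec's actual API, and all names in it are illustrative.

from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class FrameContext:
    # Mirrors the fields of DecodedFrameContext in the diff above.
    width: int
    height: int
    pixel_format: int
    output_width: int
    output_height: int


class ColorConverter:
    """Illustrative stand-in for an swscale/filtergraph conversion object."""

    def __init__(self, context: FrameContext):
        self.context = context


class StreamState:
    def __init__(self) -> None:
        self.converter: Optional[ColorConverter] = None
        self.prev_context: Optional[FrameContext] = None

    def get_converter(self, context: FrameContext) -> ColorConverter:
        # Create the conversion object lazily, and re-create it only when
        # the frame context changes (e.g. a mid-stream resolution change).
        # Otherwise reuse it, which is the performance-sensitive path.
        if self.converter is None or self.prev_context != context:
            self.converter = ColorConverter(context)
            self.prev_context = context
        return self.converter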

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 2 additions & 28 deletions

@@ -164,48 +164,22 @@ class VideoDecoder {
   // Calling getNextFrameNoDemuxInternal() will return the first frame at
   // or after this position.
   void setCursorPtsInSeconds(double seconds);
-  // This is an internal structure that is used to store the decoded output
-  // from decoding a frame through color conversion. Example usage is:
-  //
-  //   RawDecodedOutput rawOutput = getDecodedOutputWithFilter();
-  //   // Now allocate a single tensor or a batch tensor.
-  //   torch::Tensor userOutput = torch::empty(...);
-  //   // Now fill in `data` and `size`.
-  //   rawOutput.data = userOutput.data_ptr();
-  //   // Now run the color conversion.
-  //   convertFrameToBufferUsingSwsScale(rawOutput);
-  //
-  // This structure ensures we always keep the streamIndex and frame together
-  // with the data output. Note that AVFrame itself doesn't retain the
-  // streamIndex.
+  // This structure ensures we always keep the streamIndex and AVFrame together
+  // Note that AVFrame itself doesn't retain the streamIndex.
   struct RawDecodedOutput {
     // The actual decoded output as a unique pointer to an AVFrame.
     UniqueAVFrame frame;
     // The stream index of the decoded frame.
     int streamIndex;
-    // This is an unowned pointer that we copy the frame data to after color
-    // conversion.
-    // For a single tensor this points to the start of data_ptr. For a batch
-    // tensor it may point to the middle of the allocated batch tensor.
-    void* data = nullptr;
-    // We carry around the size to ensure we don't stomp on memory while doing
-    // color conversion.
-    size_t size = 0;
   };
   struct DecodedOutput {
     // The actual decoded output as a Tensor.
     torch::Tensor frame;
-    // Could be AVMEDIA_TYPE_VIDEO or AVMEDIA_TYPE_AUDIO.
-    AVMediaType streamType;
     // The stream index of the decoded frame. Used to distinguish
     // between streams that are of the same type.
     int streamIndex;
-    // The presentation timestamp of the decoded frame in time base.
-    int64_t pts;
     // The presentation timestamp of the decoded frame in seconds.
     double ptsSeconds;
-    // The duration of the decoded frame in time base.
-    int64_t duration;
    // The duration of the decoded frame in seconds.
     double durationSeconds;
   };
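With `pts` and `duration` removed from `DecodedOutput`, timestamps are exposed only in seconds. Here is a hedged sketch of the underlying conversion (what `ptsToSeconds` performs in the VideoDecoder.cpp diff above), assuming a rational time base; the expected value matches the `1'001. / 30'000` check in the VideoDecoderTest.cpp diff below.

from fractions import Fraction


def pts_to_seconds(pts: int, time_base: Fraction) -> float:
    # A pts counts ticks of the stream's time base, so multiplying by the
    # time base converts it to seconds.
    return float(pts * time_base)


# A 29.97 fps stream commonly has a 1/30000 time base with frames 1001
# ticks apart, so the second frame lands at 1001/30000 seconds.
assert pts_to_seconds(1001, Fraction(1, 30000)) == 1001 / 30000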

src/torchcodec/decoders/_core/_metadata.py

Lines changed: 14 additions & 4 deletions

@@ -110,17 +110,27 @@ def average_fps(self) -> Optional[float]:
 
     @property
     def begin_stream_seconds(self) -> float:
-        """TODO."""
+        """Beginning of the stream, in seconds (float). Conceptually, this
+        corresponds to the first frame's :term:`pts`. If
+        ``begin_stream_seconds_from_content`` is not None, then it is returned.
+        Otherwise, this value is 0.
+        """
         if self.begin_stream_seconds_from_content is None:
             return 0
-        return self.begin_stream_seconds_from_content
+        else:
+            return self.begin_stream_seconds_from_content
 
     @property
     def end_stream_seconds(self) -> Optional[float]:
-        """TODO."""
+        """End of the stream, in seconds (float or None).
+        Conceptually, this corresponds to last_frame.pts + last_frame.duration.
+        If ``end_stream_seconds_from_content`` is not None, then that value is
+        returned. Otherwise, returns ``duration_seconds``.
+        """
         if self.end_stream_seconds_from_content is None:
             return self.duration_seconds
-        return self.end_stream_seconds_from_content
+        else:
+            return self.end_stream_seconds_from_content
 
     def __repr__(self):
         # Overridden because properites are not printed by default.
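The two docstrings added above each describe a fallback. A minimal, self-contained sketch of just that logic, using a simplified stand-in for the real StreamMetadata class (the field names follow the diff; everything else, including treating ``duration_seconds`` as a plain field, is illustrative):

from dataclasses import dataclass
from typing import Optional


@dataclass
class StreamMetadataSketch:
    begin_stream_seconds_from_content: Optional[float] = None
    end_stream_seconds_from_content: Optional[float] = None
    duration_seconds: Optional[float] = None

    @property
    def begin_stream_seconds(self) -> float:
        # Fall back to 0 when no from-content value is available.
        if self.begin_stream_seconds_from_content is None:
            return 0
        else:
            return self.begin_stream_seconds_from_content

    @property
    def end_stream_seconds(self) -> Optional[float]:
        # Fall back to the container duration when no from-content value
        # is available.
        if self.end_stream_seconds_from_content is None:
            return self.duration_seconds
        else:
            return self.end_stream_seconds_from_content


meta = StreamMetadataSketch(duration_seconds=10.0)
assert meta.begin_stream_seconds == 0
assert meta.end_stream_seconds == 10.0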

test/decoders/VideoDecoderTest.cpp

Lines changed: 0 additions & 2 deletions

@@ -174,12 +174,10 @@ TEST_P(VideoDecoderTest, ReturnsFirstTwoFramesOfVideo) {
   torch::Tensor tensor0FromOurDecoder = output.frame;
   EXPECT_EQ(tensor0FromOurDecoder.sizes(), std::vector<long>({3, 270, 480}));
   EXPECT_EQ(output.ptsSeconds, 0.0);
-  EXPECT_EQ(output.pts, 0);
   output = ourDecoder->getNextFrameNoDemux();
   torch::Tensor tensor1FromOurDecoder = output.frame;
   EXPECT_EQ(tensor1FromOurDecoder.sizes(), std::vector<long>({3, 270, 480}));
   EXPECT_EQ(output.ptsSeconds, 1'001. / 30'000);
-  EXPECT_EQ(output.pts, 1001);
 
   torch::Tensor tensor0FromFFMPEG =
       readTensorFromDisk("nasa_13013.mp4.stream3.frame000000.pt");

test/decoders/test_video_decoder_ops.py

Lines changed: 2 additions & 6 deletions

@@ -528,9 +528,7 @@ def test_color_conversion_library_with_scaling(
     if height_scaling_factor != 1.0:
         assert target_height != input_video.height
 
-    filtergraph_decoder = create_from_file(
-        str(input_video.path)
-    )
+    filtergraph_decoder = create_from_file(str(input_video.path))
     _add_video_stream(
         filtergraph_decoder,
         width=target_width,
@@ -539,9 +537,7 @@ def test_color_conversion_library_with_scaling(
     )
     filtergraph_frame0, _, _ = get_next_frame(filtergraph_decoder)
 
-    swscale_decoder = create_from_file(
-        str(input_video.path)
-    )
+    swscale_decoder = create_from_file(str(input_video.path))
     _add_video_stream(
         swscale_decoder,
         width=target_width,

test/samplers/test_samplers.py

Lines changed: 0 additions & 3 deletions

@@ -592,17 +592,14 @@ def restore_metadata():
     with restore_metadata():
         decoder.metadata.end_stream_seconds_from_content = None
         decoder.metadata.duration_seconds_from_header = None
-        decoder.metadata.duration_seconds_from_content = None
         with pytest.raises(
             ValueError, match="Could not infer stream end from video metadata"
         ):
             sampler(decoder)
 
     with restore_metadata():
-        decoder.metadata.begin_stream_seconds_from_content = None
         decoder.metadata.end_stream_seconds_from_content = None
         decoder.metadata.average_fps_from_header = None
-        decoder.metadata.duration_seconds_from_header = None
        with pytest.raises(ValueError, match="Could not infer average fps"):
             sampler(decoder)
