Skip to content

Commit a7c5711

Browse files
committed
Use allStreamMetadata
1 parent 11779a7 commit a7c5711

File tree

4 files changed

+40
-29
lines changed

4 files changed

+40
-29
lines changed

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 23 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -289,7 +289,7 @@ void VideoDecoder::initializeDecoder() {
289289
containerMetadata_.numAudioStreams++;
290290
}
291291

292-
containerMetadata_.streamMetadatas.push_back(streamMetadata);
292+
containerMetadata_.allStreamMetadata.push_back(streamMetadata);
293293
}
294294

295295
if (formatContext_->duration > 0) {
@@ -487,7 +487,7 @@ void VideoDecoder::addVideoStreamDecoder(
487487
}
488488

489489
StreamMetadata& streamMetadata =
490-
containerMetadata_.streamMetadatas[streamIndex];
490+
containerMetadata_.allStreamMetadata[streamIndex];
491491
if (seekMode_ == SeekMode::approximate &&
492492
!streamMetadata.averageFps.has_value()) {
493493
throw std::runtime_error(
@@ -539,10 +539,11 @@ void VideoDecoder::addVideoStreamDecoder(
539539
void VideoDecoder::updateMetadataWithCodecContext(
540540
int streamIndex,
541541
AVCodecContext* codecContext) {
542-
containerMetadata_.streamMetadatas[streamIndex].width = codecContext->width;
543-
containerMetadata_.streamMetadatas[streamIndex].height = codecContext->height;
542+
containerMetadata_.allStreamMetadata[streamIndex].width = codecContext->width;
543+
containerMetadata_.allStreamMetadata[streamIndex].height =
544+
codecContext->height;
544545
auto codedId = codecContext->codec_id;
545-
containerMetadata_.streamMetadatas[streamIndex].codecName =
546+
containerMetadata_.allStreamMetadata[streamIndex].codecName =
546547
std::string(avcodec_get_name(codedId));
547548
}
548549

@@ -603,7 +604,7 @@ void VideoDecoder::scanFileAndUpdateMetadataAndIndex() {
603604
// We got a valid packet. Let's figure out what stream it belongs to and
604605
// record its relevant metadata.
605606
int streamIndex = packet->stream_index;
606-
auto& streamMetadata = containerMetadata_.streamMetadatas[streamIndex];
607+
auto& streamMetadata = containerMetadata_.allStreamMetadata[streamIndex];
607608
streamMetadata.minPtsFromScan = std::min(
608609
streamMetadata.minPtsFromScan.value_or(INT64_MAX), packet->pts);
609610
streamMetadata.maxPtsFromScan = std::max(
@@ -624,9 +625,9 @@ void VideoDecoder::scanFileAndUpdateMetadataAndIndex() {
624625
// Set all per-stream metadata that requires knowing the content of all
625626
// packets.
626627
for (size_t streamIndex = 0;
627-
streamIndex < containerMetadata_.streamMetadatas.size();
628+
streamIndex < containerMetadata_.allStreamMetadata.size();
628629
++streamIndex) {
629-
auto& streamMetadata = containerMetadata_.streamMetadatas[streamIndex];
630+
auto& streamMetadata = containerMetadata_.allStreamMetadata[streamIndex];
630631
auto avStream = formatContext_->streams[streamIndex];
631632

632633
streamMetadata.numFramesFromScan =
@@ -1104,7 +1105,8 @@ VideoDecoder::DecodedOutput VideoDecoder::getFramePlayedAtTimestampNoDemux(
11041105
}
11051106

11061107
void VideoDecoder::validateUserProvidedStreamIndex(int streamIndex) {
1107-
int streamsSize = static_cast<int>(containerMetadata_.streamMetadatas.size());
1108+
int streamsSize =
1109+
static_cast<int>(containerMetadata_.allStreamMetadata.size());
11081110
TORCH_CHECK(
11091111
streamIndex >= 0 && streamIndex < streamsSize,
11101112
"Invalid stream index=" + std::to_string(streamIndex) +
@@ -1243,7 +1245,8 @@ VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndexInternal(
12431245
validateUserProvidedStreamIndex(streamIndex);
12441246

12451247
const auto& streamInfo = streamInfos_[streamIndex];
1246-
const auto& streamMetadata = containerMetadata_.streamMetadatas[streamIndex];
1248+
const auto& streamMetadata =
1249+
containerMetadata_.allStreamMetadata[streamIndex];
12471250
validateFrameIndex(streamMetadata, frameIndex);
12481251

12491252
int64_t pts = getPts(streamInfo, streamMetadata, frameIndex);
@@ -1275,7 +1278,8 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices(
12751278
});
12761279
}
12771280

1278-
const auto& streamMetadata = containerMetadata_.streamMetadatas[streamIndex];
1281+
const auto& streamMetadata =
1282+
containerMetadata_.allStreamMetadata[streamIndex];
12791283
const auto& streamInfo = streamInfos_[streamIndex];
12801284
const auto& videoStreamOptions = streamInfo.videoStreamOptions;
12811285
BatchDecodedOutput output(
@@ -1313,7 +1317,8 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesPlayedByTimestamps(
13131317
const std::vector<double>& timestamps) {
13141318
validateUserProvidedStreamIndex(streamIndex);
13151319

1316-
const auto& streamMetadata = containerMetadata_.streamMetadatas[streamIndex];
1320+
const auto& streamMetadata =
1321+
containerMetadata_.allStreamMetadata[streamIndex];
13171322
const auto& streamInfo = streamInfos_[streamIndex];
13181323

13191324
double minSeconds = getMinSeconds(streamMetadata);
@@ -1347,7 +1352,8 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesInRange(
13471352
int64_t step) {
13481353
validateUserProvidedStreamIndex(streamIndex);
13491354

1350-
const auto& streamMetadata = containerMetadata_.streamMetadatas[streamIndex];
1355+
const auto& streamMetadata =
1356+
containerMetadata_.allStreamMetadata[streamIndex];
13511357
const auto& streamInfo = streamInfos_[streamIndex];
13521358
int64_t numFrames = getNumFrames(streamMetadata);
13531359
TORCH_CHECK(
@@ -1381,7 +1387,8 @@ VideoDecoder::getFramesPlayedByTimestampInRange(
13811387
double stopSeconds) {
13821388
validateUserProvidedStreamIndex(streamIndex);
13831389

1384-
const auto& streamMetadata = containerMetadata_.streamMetadatas[streamIndex];
1390+
const auto& streamMetadata =
1391+
containerMetadata_.allStreamMetadata[streamIndex];
13851392
TORCH_CHECK(
13861393
startSeconds <= stopSeconds,
13871394
"Start seconds (" + std::to_string(startSeconds) +
@@ -1498,7 +1505,8 @@ double VideoDecoder::getPtsSecondsForFrame(
14981505
validateScannedAllStreams("getPtsSecondsForFrame");
14991506

15001507
const auto& streamInfo = streamInfos_[streamIndex];
1501-
const auto& streamMetadata = containerMetadata_.streamMetadatas[streamIndex];
1508+
const auto& streamMetadata =
1509+
containerMetadata_.allStreamMetadata[streamIndex];
15021510
validateFrameIndex(streamMetadata, frameIndex);
15031511

15041512
return ptsToSeconds(

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -104,7 +104,7 @@ class VideoDecoder {
104104
std::optional<int64_t> height;
105105
};
106106
struct ContainerMetadata {
107-
std::vector<StreamMetadata> streamMetadatas;
107+
std::vector<StreamMetadata> allStreamMetadata;
108108
int numAudioStreams = 0;
109109
int numVideoStreams = 0;
110110
// Note that this is the container-level duration, which is usually the max

src/torchcodec/decoders/_core/VideoDecoderOps.cpp

Lines changed: 11 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -345,10 +345,11 @@ std::string get_json_metadata(at::Tensor& decoder) {
345345
// serialize the metadata into a string std::stringstream ss;
346346
double durationSeconds = 0;
347347
if (maybeBestVideoStreamIndex.has_value() &&
348-
videoMetadata.streamMetadatas[*maybeBestVideoStreamIndex]
348+
videoMetadata.allStreamMetadata[*maybeBestVideoStreamIndex]
349349
.durationSeconds.has_value()) {
350-
durationSeconds = videoMetadata.streamMetadatas[*maybeBestVideoStreamIndex]
351-
.durationSeconds.value_or(0);
350+
durationSeconds =
351+
videoMetadata.allStreamMetadata[*maybeBestVideoStreamIndex]
352+
.durationSeconds.value_or(0);
352353
} else {
353354
// Fallback to container-level duration if stream duration is not found.
354355
durationSeconds = videoMetadata.durationSeconds.value_or(0);
@@ -361,7 +362,7 @@ std::string get_json_metadata(at::Tensor& decoder) {
361362

362363
if (maybeBestVideoStreamIndex.has_value()) {
363364
auto streamMetadata =
364-
videoMetadata.streamMetadatas[*maybeBestVideoStreamIndex];
365+
videoMetadata.allStreamMetadata[*maybeBestVideoStreamIndex];
365366
if (streamMetadata.numFramesFromScan.has_value()) {
366367
metadataMap["numFrames"] =
367368
std::to_string(*streamMetadata.numFramesFromScan);
@@ -425,7 +426,8 @@ std::string get_container_json_metadata(at::Tensor& decoder) {
425426
std::to_string(*containerMetadata.bestAudioStreamIndex);
426427
}
427428

428-
map["numStreams"] = std::to_string(containerMetadata.streamMetadatas.size());
429+
map["numStreams"] =
430+
std::to_string(containerMetadata.allStreamMetadata.size());
429431

430432
return mapToJson(map);
431433
}
@@ -434,13 +436,14 @@ std::string get_stream_json_metadata(
434436
at::Tensor& decoder,
435437
int64_t stream_index) {
436438
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
437-
auto streamMetadatas = videoDecoder->getContainerMetadata().streamMetadatas;
439+
auto allStreamMetadata =
440+
videoDecoder->getContainerMetadata().allStreamMetadata;
438441
if (stream_index < 0 ||
439-
stream_index >= static_cast<int64_t>(streamMetadatas.size())) {
442+
stream_index >= static_cast<int64_t>(allStreamMetadata.size())) {
440443
throw std::out_of_range(
441444
"stream_index out of bounds: " + std::to_string(stream_index));
442445
}
443-
auto streamMetadata = streamMetadatas[stream_index];
446+
auto streamMetadata = allStreamMetadata[stream_index];
444447

445448
std::map<std::string, std::string> map;
446449

test/decoders/VideoDecoderTest.cpp

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -72,8 +72,8 @@ TEST_P(VideoDecoderTest, ReturnsFpsAndDurationForVideoInMetadata) {
7272
#else
7373
EXPECT_NEAR(metadata.bitRate.value(), 324915, 1e-1);
7474
#endif
75-
EXPECT_EQ(metadata.streamMetadatas.size(), 6);
76-
const auto& videoStream = metadata.streamMetadatas[3];
75+
EXPECT_EQ(metadata.allStreamMetadata.size(), 6);
76+
const auto& videoStream = metadata.allStreamMetadata[3];
7777
EXPECT_EQ(videoStream.mediaType, AVMEDIA_TYPE_VIDEO);
7878
EXPECT_EQ(videoStream.codecName, "h264");
7979
EXPECT_NEAR(*videoStream.averageFps, 29.97f, 1e-1);
@@ -85,7 +85,7 @@ TEST_P(VideoDecoderTest, ReturnsFpsAndDurationForVideoInMetadata) {
8585
EXPECT_FALSE(videoStream.numFramesFromScan.has_value());
8686
decoder->scanFileAndUpdateMetadataAndIndex();
8787
metadata = decoder->getContainerMetadata();
88-
const auto& videoStream1 = metadata.streamMetadatas[3];
88+
const auto& videoStream1 = metadata.allStreamMetadata[3];
8989
EXPECT_EQ(*videoStream1.minPtsSecondsFromScan, 0);
9090
EXPECT_EQ(*videoStream1.maxPtsSecondsFromScan, 13.013);
9191
EXPECT_EQ(*videoStream1.numFramesFromScan, 390);
@@ -428,9 +428,9 @@ TEST_P(VideoDecoderTest, GetAudioMetadata) {
428428
VideoDecoder::ContainerMetadata metadata = decoder->getContainerMetadata();
429429
EXPECT_EQ(metadata.numAudioStreams, 1);
430430
EXPECT_EQ(metadata.numVideoStreams, 0);
431-
EXPECT_EQ(metadata.streamMetadatas.size(), 1);
431+
EXPECT_EQ(metadata.allStreamMetadata.size(), 1);
432432

433-
const auto& audioStream = metadata.streamMetadatas[0];
433+
const auto& audioStream = metadata.allStreamMetadata[0];
434434
EXPECT_EQ(audioStream.mediaType, AVMEDIA_TYPE_AUDIO);
435435
EXPECT_NEAR(*audioStream.durationSeconds, 13.25, 1e-1);
436436
}

0 commit comments

Comments
 (0)