Skip to content

Commit 8bee8d7

Browse files
author
Molly Xu
committed
refactor metadata fallback logic
1 parent 169484d commit 8bee8d7

File tree

11 files changed

+207
-227
lines changed

11 files changed

+207
-227
lines changed

src/torchcodec/_core/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ function(make_torchcodec_libraries
9696
Encoder.cpp
9797
ValidationUtils.cpp
9898
Transform.cpp
99+
Metadata.cpp
99100
)
100101

101102
if(ENABLE_CUDA)

src/torchcodec/_core/Metadata.cpp

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
// Copyright (c) Meta Platforms, Inc. and affiliates.
2+
// All rights reserved.
3+
//
4+
// This source code is licensed under the BSD-style license found in the
5+
// LICENSE file in the root directory of this source tree.
6+
7+
#include "Metadata.h"
8+
9+
namespace facebook::torchcodec {
10+
11+
std::optional<double> StreamMetadata::getDurationSeconds(
12+
SeekMode seekMode) const {
13+
switch (seekMode) {
14+
case SeekMode::custom_frame_mappings:
15+
case SeekMode::exact:
16+
// In exact mode, use the scanned content value
17+
if (endStreamPtsSecondsFromContent.has_value() &&
18+
beginStreamPtsSecondsFromContent.has_value()) {
19+
return endStreamPtsSecondsFromContent.value() -
20+
beginStreamPtsSecondsFromContent.value();
21+
}
22+
return std::nullopt;
23+
case SeekMode::approximate:
24+
if (durationSecondsFromHeader.has_value()) {
25+
return durationSecondsFromHeader.value();
26+
}
27+
if (numFramesFromHeader.has_value() && averageFpsFromHeader.has_value() &&
28+
averageFpsFromHeader.value() != 0.0) {
29+
return static_cast<double>(numFramesFromHeader.value()) /
30+
averageFpsFromHeader.value();
31+
}
32+
return std::nullopt;
33+
}
34+
return std::nullopt;
35+
}
36+
37+
double StreamMetadata::getBeginStreamSeconds(SeekMode seekMode) const {
38+
switch (seekMode) {
39+
case SeekMode::custom_frame_mappings:
40+
case SeekMode::exact:
41+
if (beginStreamPtsSecondsFromContent.has_value()) {
42+
return beginStreamPtsSecondsFromContent.value();
43+
}
44+
return 0.0;
45+
case SeekMode::approximate:
46+
return 0.0;
47+
}
48+
return 0.0;
49+
}
50+
51+
std::optional<double> StreamMetadata::getEndStreamSeconds(
52+
SeekMode seekMode) const {
53+
switch (seekMode) {
54+
case SeekMode::custom_frame_mappings:
55+
case SeekMode::exact:
56+
if (endStreamPtsSecondsFromContent.has_value()) {
57+
return endStreamPtsSecondsFromContent.value();
58+
}
59+
return getDurationSeconds(seekMode);
60+
case SeekMode::approximate:
61+
return getDurationSeconds(seekMode);
62+
}
63+
return std::nullopt;
64+
}
65+
66+
std::optional<int64_t> StreamMetadata::getNumFrames(SeekMode seekMode) const {
67+
switch (seekMode) {
68+
case SeekMode::custom_frame_mappings:
69+
case SeekMode::exact:
70+
if (numFramesFromContent.has_value()) {
71+
return numFramesFromContent.value();
72+
}
73+
return std::nullopt;
74+
case SeekMode::approximate: {
75+
if (numFramesFromHeader.has_value()) {
76+
return numFramesFromHeader.value();
77+
}
78+
if (averageFpsFromHeader.has_value() &&
79+
durationSecondsFromHeader.has_value()) {
80+
return static_cast<int64_t>(
81+
averageFpsFromHeader.value() * durationSecondsFromHeader.value());
82+
}
83+
return std::nullopt;
84+
}
85+
}
86+
return std::nullopt;
87+
}
88+
89+
std::optional<double> StreamMetadata::getAverageFps(SeekMode seekMode) const {
90+
switch (seekMode) {
91+
case SeekMode::custom_frame_mappings:
92+
case SeekMode::exact:
93+
if (getNumFrames(seekMode).has_value() &&
94+
beginStreamPtsSecondsFromContent.has_value() &&
95+
endStreamPtsSecondsFromContent.has_value() &&
96+
(beginStreamPtsSecondsFromContent.value() !=
97+
endStreamPtsSecondsFromContent.value())) {
98+
return static_cast<double>(
99+
getNumFrames(seekMode).value() /
100+
(endStreamPtsSecondsFromContent.value() -
101+
beginStreamPtsSecondsFromContent.value()));
102+
}
103+
return averageFpsFromHeader;
104+
case SeekMode::approximate:
105+
return averageFpsFromHeader;
106+
}
107+
return std::nullopt;
108+
}
109+
110+
} // namespace facebook::torchcodec

src/torchcodec/_core/Metadata.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
#include <string>
1111
#include <vector>
1212

13+
#include "SeekMode.h"
14+
1315
extern "C" {
1416
#include <libavcodec/avcodec.h>
1517
#include <libavutil/avutil.h>
@@ -52,6 +54,13 @@ struct StreamMetadata {
5254
std::optional<int64_t> sampleRate;
5355
std::optional<int64_t> numChannels;
5456
std::optional<std::string> sampleFormat;
57+
58+
// Computed methods with fallback logic
59+
std::optional<double> getDurationSeconds(SeekMode seekMode) const;
60+
double getBeginStreamSeconds(SeekMode seekMode) const;
61+
std::optional<double> getEndStreamSeconds(SeekMode seekMode) const;
62+
std::optional<int64_t> getNumFrames(SeekMode seekMode) const;
63+
std::optional<double> getAverageFps(SeekMode seekMode) const;
5564
};
5665

5766
struct ContainerMetadata {

src/torchcodec/_core/SeekMode.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
// Copyright (c) Meta Platforms, Inc. and affiliates.
2+
// All rights reserved.
3+
//
4+
// This source code is licensed under the BSD-style license found in the
5+
// LICENSE file in the root directory of this source tree.
6+
7+
#pragma once
8+
9+
namespace facebook::torchcodec {
10+
11+
enum class SeekMode { exact, approximate, custom_frame_mappings };
12+
13+
} // namespace facebook::torchcodec

src/torchcodec/_core/SingleStreamDecoder.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,10 @@ ContainerMetadata SingleStreamDecoder::getContainerMetadata() const {
367367
return containerMetadata_;
368368
}
369369

370+
SeekMode SingleStreamDecoder::getSeekMode() const {
371+
return seekMode_;
372+
}
373+
370374
torch::Tensor SingleStreamDecoder::getKeyFrameIndices() {
371375
validateActiveStream(AVMEDIA_TYPE_VIDEO);
372376
validateScannedAllStreams("getKeyFrameIndices");

src/torchcodec/_core/SingleStreamDecoder.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "DeviceInterface.h"
1717
#include "FFMPEGCommon.h"
1818
#include "Frame.h"
19+
#include "SeekMode.h"
1920
#include "StreamOptions.h"
2021
#include "Transform.h"
2122

@@ -30,8 +31,6 @@ class SingleStreamDecoder {
3031
// CONSTRUCTION API
3132
// --------------------------------------------------------------------------
3233

33-
enum class SeekMode { exact, approximate, custom_frame_mappings };
34-
3534
// Creates a SingleStreamDecoder from the video at videoFilePath.
3635
explicit SingleStreamDecoder(
3736
const std::string& videoFilePath,
@@ -60,6 +59,9 @@ class SingleStreamDecoder {
6059
// Returns the metadata for the container.
6160
ContainerMetadata getContainerMetadata() const;
6261

62+
// Returns the seek mode of this decoder.
63+
SeekMode getSeekMode() const;
64+
6365
// Returns the key frame indices as a tensor. The tensor is 1D and contains
6466
// int64 values, where each value is the frame index for a key frame.
6567
torch::Tensor getKeyFrameIndices();

src/torchcodec/_core/_metadata.py

Lines changed: 27 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -38,14 +38,22 @@ class StreamMetadata:
3838
stream_index: int
3939
"""Index of the stream that this metadata refers to (int)."""
4040

41+
# Computed fields (computed in C++ with fallback logic)
42+
duration_seconds: Optional[float]
43+
"""Duration of the stream in seconds. Computed in C++ with fallback logic:
44+
tries to calculate from content if scan was performed, otherwise falls back
45+
to header values."""
46+
begin_stream_seconds: Optional[float]
47+
"""Beginning of the stream, in seconds. Computed in C++ with fallback logic."""
48+
4149
def __repr__(self):
4250
s = self.__class__.__name__ + ":\n"
4351
for field in dataclasses.fields(self):
4452
s += f"{SPACES}{field.name}: {getattr(self, field.name)}\n"
4553
return s
4654

4755

48-
@dataclass
56+
@dataclass(repr=False)
4957
class VideoStreamMetadata(StreamMetadata):
5058
"""Metadata of a single video stream."""
5159

@@ -87,103 +95,19 @@ class VideoStreamMetadata(StreamMetadata):
8795
is the ratio between the width and height of each pixel
8896
(``fractions.Fraction`` or None)."""
8997

90-
@property
91-
def duration_seconds(self) -> Optional[float]:
92-
"""Duration of the stream in seconds. We try to calculate the duration
93-
from the actual frames if a :term:`scan` was performed. Otherwise we
94-
fall back to ``duration_seconds_from_header``. If that value is also None,
95-
we instead calculate the duration from ``num_frames_from_header`` and
96-
``average_fps_from_header``.
97-
"""
98-
if (
99-
self.end_stream_seconds_from_content is not None
100-
and self.begin_stream_seconds_from_content is not None
101-
):
102-
return (
103-
self.end_stream_seconds_from_content
104-
- self.begin_stream_seconds_from_content
105-
)
106-
elif self.duration_seconds_from_header is not None:
107-
return self.duration_seconds_from_header
108-
elif (
109-
self.num_frames_from_header is not None
110-
and self.average_fps_from_header is not None
111-
):
112-
return self.num_frames_from_header / self.average_fps_from_header
113-
else:
114-
return None
115-
116-
@property
117-
def begin_stream_seconds(self) -> float:
118-
"""Beginning of the stream, in seconds (float). Conceptually, this
119-
corresponds to the first frame's :term:`pts`. If
120-
``begin_stream_seconds_from_content`` is not None, then it is returned.
121-
Otherwise, this value is 0.
122-
"""
123-
if self.begin_stream_seconds_from_content is None:
124-
return 0
125-
else:
126-
return self.begin_stream_seconds_from_content
127-
128-
@property
129-
def end_stream_seconds(self) -> Optional[float]:
130-
"""End of the stream, in seconds (float or None).
131-
Conceptually, this corresponds to last_frame.pts + last_frame.duration.
132-
If ``end_stream_seconds_from_content`` is not None, then that value is
133-
returned. Otherwise, returns ``duration_seconds``.
134-
"""
135-
if self.end_stream_seconds_from_content is None:
136-
return self.duration_seconds
137-
else:
138-
return self.end_stream_seconds_from_content
139-
140-
@property
141-
def num_frames(self) -> Optional[int]:
142-
"""Number of frames in the stream (int or None).
143-
This corresponds to ``num_frames_from_content`` if a :term:`scan` was made,
144-
otherwise it corresponds to ``num_frames_from_header``. If that value is also
145-
None, the number of frames is calculated from the duration and the average fps.
146-
"""
147-
if self.num_frames_from_content is not None:
148-
return self.num_frames_from_content
149-
elif self.num_frames_from_header is not None:
150-
return self.num_frames_from_header
151-
elif (
152-
self.average_fps_from_header is not None
153-
and self.duration_seconds_from_header is not None
154-
):
155-
return int(self.average_fps_from_header * self.duration_seconds_from_header)
156-
else:
157-
return None
158-
159-
@property
160-
def average_fps(self) -> Optional[float]:
161-
"""Average fps of the stream. If a :term:`scan` was perfomed, this is
162-
computed from the number of frames and the duration of the stream.
163-
Otherwise we fall back to ``average_fps_from_header``.
164-
"""
165-
if (
166-
self.end_stream_seconds_from_content is None
167-
or self.begin_stream_seconds_from_content is None
168-
or self.num_frames is None
169-
# Should never happen, but prevents ZeroDivisionError:
170-
or self.end_stream_seconds_from_content
171-
== self.begin_stream_seconds_from_content
172-
):
173-
return self.average_fps_from_header
174-
return self.num_frames / (
175-
self.end_stream_seconds_from_content
176-
- self.begin_stream_seconds_from_content
177-
)
178-
179-
def __repr__(self):
180-
s = super().__repr__()
181-
s += f"{SPACES}duration_seconds: {self.duration_seconds}\n"
182-
s += f"{SPACES}begin_stream_seconds: {self.begin_stream_seconds}\n"
183-
s += f"{SPACES}end_stream_seconds: {self.end_stream_seconds}\n"
184-
s += f"{SPACES}num_frames: {self.num_frames}\n"
185-
s += f"{SPACES}average_fps: {self.average_fps}\n"
186-
return s
98+
# Computed fields (computed in C++ with fallback logic)
99+
end_stream_seconds: Optional[float]
100+
"""End of the stream, in seconds (float or None).
101+
Conceptually, this corresponds to last_frame.pts + last_frame.duration.
102+
Computed in C++ with fallback logic."""
103+
num_frames: Optional[int]
104+
"""Number of frames in the stream (int or None).
105+
Computed in C++ with fallback logic: uses content if scan was performed,
106+
otherwise falls back to header values or calculates from duration and fps."""
107+
average_fps: Optional[float]
108+
"""Average fps of the stream (float or None).
109+
Computed in C++ with fallback logic: if scan was performed, computes from
110+
num_frames and duration, otherwise uses header value."""
187111

188112

189113
@dataclass
@@ -260,10 +184,12 @@ def get_container_metadata(decoder: torch.Tensor) -> ContainerMetadata:
260184
stream_dict = json.loads(_get_stream_json_metadata(decoder, stream_index))
261185
common_meta = dict(
262186
duration_seconds_from_header=stream_dict.get("durationSecondsFromHeader"),
187+
duration_seconds=stream_dict.get("durationSeconds"),
263188
bit_rate=stream_dict.get("bitRate"),
264189
begin_stream_seconds_from_header=stream_dict.get(
265190
"beginStreamSecondsFromHeader"
266191
),
192+
begin_stream_seconds=stream_dict.get("beginStreamSeconds"),
267193
codec=stream_dict.get("codec"),
268194
stream_index=stream_index,
269195
)
@@ -276,6 +202,9 @@ def get_container_metadata(decoder: torch.Tensor) -> ContainerMetadata:
276202
end_stream_seconds_from_content=stream_dict.get(
277203
"endStreamSecondsFromContent"
278204
),
205+
end_stream_seconds=stream_dict.get("endStreamSeconds"),
206+
num_frames=stream_dict.get("numFrames"),
207+
average_fps=stream_dict.get("averageFps"),
279208
width=stream_dict.get("width"),
280209
height=stream_dict.get("height"),
281210
num_frames_from_header=stream_dict.get("numFramesFromHeader"),

0 commit comments

Comments
 (0)