Skip to content

Commit abc9b10

Browse files
authored
Move "scanned" metadata as video-only and add begin_stream_from_header (#591)
1 parent 2e6f0ed commit abc9b10

File tree

9 files changed

+194
-137
lines changed

9 files changed

+194
-137
lines changed

src/torchcodec/decoders/_audio_decoder.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from torchcodec.decoders import _core as core
1414
from torchcodec.decoders._decoder_utils import (
1515
create_decoder,
16-
get_and_validate_stream_metadata,
16+
ERROR_REPORTING_INSTRUCTIONS,
1717
)
1818

1919

@@ -57,15 +57,20 @@ def __init__(
5757
self._decoder, stream_index=stream_index, sample_rate=sample_rate
5858
)
5959

60-
(
61-
self.metadata,
62-
self.stream_index,
63-
self._begin_stream_seconds,
64-
self._end_stream_seconds,
65-
) = get_and_validate_stream_metadata(
66-
decoder=self._decoder, stream_index=stream_index, media_type="audio"
60+
container_metadata = core.get_container_metadata(self._decoder)
61+
self.stream_index = (
62+
container_metadata.best_audio_stream_index
63+
if stream_index is None
64+
else stream_index
6765
)
66+
if self.stream_index is None:
67+
raise ValueError(
68+
"The best audio stream is unknown and there is no specified stream. "
69+
+ ERROR_REPORTING_INSTRUCTIONS
70+
)
71+
self.metadata = container_metadata.streams[self.stream_index]
6872
assert isinstance(self.metadata, core.AudioStreamMetadata) # mypy
73+
6974
self._desired_sample_rate = (
7075
sample_rate if sample_rate is not None else self.metadata.sample_rate
7176
)
@@ -90,12 +95,6 @@ def get_samples_played_in_range(
9095
raise ValueError(
9196
f"Invalid start seconds: {start_seconds}. It must be less than or equal to stop seconds ({stop_seconds})."
9297
)
93-
if not self._begin_stream_seconds <= start_seconds < self._end_stream_seconds:
94-
raise ValueError(
95-
f"Invalid start seconds: {start_seconds}. "
96-
f"It must be greater than or equal to {self._begin_stream_seconds} "
97-
f"and less than or equal to {self._end_stream_seconds}."
98-
)
9998
frames, first_pts = core.get_frames_by_pts_in_range_audio(
10099
self._decoder,
101100
start_seconds=start_seconds,

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,10 @@ void VideoDecoder::initializeDecoder() {
147147
streamMetadata.durationSeconds =
148148
av_q2d(avStream->time_base) * avStream->duration;
149149
}
150+
if (avStream->start_time != AV_NOPTS_VALUE) {
151+
streamMetadata.beginStreamFromHeader =
152+
av_q2d(avStream->time_base) * avStream->start_time;
153+
}
150154

151155
if (avStream->codecpar->codec_type == AVMEDIA_TYPE_VIDEO) {
152156
double fps = av_q2d(avStream->r_frame_rate);
@@ -944,8 +948,9 @@ VideoDecoder::AudioFramesOutput VideoDecoder::getFramesPlayedInRangeAudio(
944948
TORCH_CHECK(
945949
frames.size() > 0 && firstFramePtsSeconds.has_value(),
946950
"No audio frames were decoded. ",
947-
"This should probably not happen. ",
948-
"Please report an issue on the TorchCodec repo.");
951+
"This is probably because start_seconds is too high? ",
952+
"Current value is ",
953+
startSeconds);
949954

950955
return AudioFramesOutput{torch::cat(frames, 1), *firstFramePtsSeconds};
951956
}

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ class VideoDecoder {
5959
std::optional<AVCodecID> codecId;
6060
std::optional<std::string> codecName;
6161
std::optional<double> durationSeconds;
62+
std::optional<double> beginStreamFromHeader;
6263
std::optional<int64_t> numFrames;
6364
std::optional<int64_t> numKeyFrames;
6465
std::optional<double> averageFps;

src/torchcodec/decoders/_core/VideoDecoderOps.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -473,6 +473,10 @@ std::string get_stream_json_metadata(
473473
if (streamMetadata.numFrames.has_value()) {
474474
map["numFrames"] = std::to_string(*streamMetadata.numFrames);
475475
}
476+
if (streamMetadata.beginStreamFromHeader.has_value()) {
477+
map["beginStreamFromHeader"] =
478+
std::to_string(*streamMetadata.beginStreamFromHeader);
479+
}
476480
if (streamMetadata.minPtsSecondsFromScan.has_value()) {
477481
map["minPtsSecondsFromScan"] =
478482
std::to_string(*streamMetadata.minPtsSecondsFromScan);

src/torchcodec/decoders/_core/_metadata.py

Lines changed: 55 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -32,27 +32,59 @@ class StreamMetadata:
3232
duration_seconds_from_header: Optional[float]
3333
"""Duration of the stream, in seconds, obtained from the header (float or
3434
None). This could be inaccurate."""
35+
begin_stream_seconds_from_header: Optional[float]
36+
"""Beginning of the stream, in seconds, obtained from the header (float or
37+
None). Usually, this is equal to 0."""
3538
bit_rate: Optional[float]
3639
"""Bit rate of the stream, in bits per second (float or None)."""
40+
codec: Optional[str]
41+
"""Codec (str or None)."""
42+
stream_index: int
43+
"""Index of the stream within the video (int)."""
44+
45+
def __repr__(self):
46+
s = self.__class__.__name__ + ":\n"
47+
for field in dataclasses.fields(self):
48+
s += f"{SPACES}{field.name}: {getattr(self, field.name)}\n"
49+
return s
50+
51+
52+
@dataclass
53+
class VideoStreamMetadata(StreamMetadata):
54+
"""Metadata of a single video stream."""
55+
3756
begin_stream_seconds_from_content: Optional[float]
3857
"""Beginning of the stream, in seconds (float or None).
39-
Conceptually, this corresponds to the first frame's :term:`pts`. It is
40-
computed as min(frame.pts) across all frames in the stream. Usually, this is
41-
equal to 0."""
58+
Conceptually, this corresponds to the first frame's :term:`pts`. It is only
59+
computed when a :term:`scan` is done as min(frame.pts) across all frames in
60+
the stream. Usually, this is equal to 0."""
4261
end_stream_seconds_from_content: Optional[float]
4362
"""End of the stream, in seconds (float or None).
4463
Conceptually, this corresponds to last_frame.pts + last_frame.duration. It
45-
is computed as max(frame.pts + frame.duration) across all frames in the
46-
stream. Note that no frame is played at this time value, so calling
47-
:meth:`~torchcodec.decoders.VideoDecoder.get_frame_played_at` with
48-
this value would result in an error. Retrieving the last frame is best done
49-
by simply indexing the :class:`~torchcodec.decoders.VideoDecoder`
50-
object with ``[-1]``.
64+
is only computed when a :term:`scan` is done as max(frame.pts +
65+
frame.duration) across all frames in the stream. Note that no frame is
66+
played at this time value, so calling
67+
:meth:`~torchcodec.decoders.VideoDecoder.get_frame_played_at` with this
68+
value would result in an error. Retrieving the last frame is best done by
69+
simply indexing the :class:`~torchcodec.decoders.VideoDecoder` object with
70+
``[-1]``.
5171
"""
52-
codec: Optional[str]
53-
"""Codec (str or None)."""
54-
stream_index: int
55-
"""Index of the stream within the video (int)."""
72+
width: Optional[int]
73+
"""Width of the frames (int or None)."""
74+
height: Optional[int]
75+
"""Height of the frames (int or None)."""
76+
num_frames_from_header: Optional[int]
77+
"""Number of frames, from the stream's metadata. This is potentially
78+
inaccurate. We recommend using the ``num_frames`` attribute instead.
79+
(int or None)."""
80+
num_frames_from_content: Optional[int]
81+
"""Number of frames computed by TorchCodec by scanning the stream's
82+
content (the scan doesn't involve decoding). This is more accurate
83+
than ``num_frames_from_header``. We recommend using the
84+
``num_frames`` attribute instead. (int or None)."""
85+
average_fps_from_header: Optional[float]
85+
"""Average fps of the stream, obtained from the header (float or None).
87+
We recommend using the ``average_fps`` attribute instead."""
5688

5789
@property
5890
def duration_seconds(self) -> Optional[float]:
@@ -94,36 +126,6 @@ def end_stream_seconds(self) -> Optional[float]:
94126
else:
95127
return self.end_stream_seconds_from_content
96128

97-
def __repr__(self):
98-
# Overridden because properties are not printed by default.
99-
s = self.__class__.__name__ + ":\n"
100-
s += f"{SPACES}duration_seconds: {self.duration_seconds}\n"
101-
for field in dataclasses.fields(self):
102-
s += f"{SPACES}{field.name}: {getattr(self, field.name)}\n"
103-
return s
104-
105-
106-
@dataclass
107-
class VideoStreamMetadata(StreamMetadata):
108-
"""Metadata of a single video stream."""
109-
110-
width: Optional[int]
111-
"""Width of the frames (int or None)."""
112-
height: Optional[int]
113-
"""Height of the frames (int or None)."""
114-
num_frames_from_header: Optional[int]
115-
"""Number of frames, from the stream's metadata. This is potentially
116-
inaccurate. We recommend using the ``num_frames`` attribute instead.
117-
(int or None)."""
118-
num_frames_from_content: Optional[int]
119-
"""Number of frames computed by TorchCodec by scanning the stream's
120-
content (the scan doesn't involve decoding). This is more accurate
121-
than ``num_frames_from_header``. We recommend using the
122-
``num_frames`` attribute instead. (int or None)."""
123-
average_fps_from_header: Optional[float]
124-
"""Average fps of the stream, obtained from the header (float or None).
125-
We recommend using the ``average_fps`` attribute instead."""
126-
127129
@property
128130
def num_frames(self) -> Optional[int]:
129131
"""Number of frames in the stream. This corresponds to
@@ -154,6 +156,9 @@ def average_fps(self) -> Optional[float]:
154156

155157
def __repr__(self):
156158
s = super().__repr__()
159+
s += f"{SPACES}duration_seconds: {self.duration_seconds}\n"
160+
s += f"{SPACES}begin_stream_seconds: {self.begin_stream_seconds}\n"
161+
s += f"{SPACES}end_stream_seconds: {self.end_stream_seconds}\n"
157162
s += f"{SPACES}num_frames: {self.num_frames}\n"
158163
s += f"{SPACES}average_fps: {self.average_fps}\n"
159164
return s
@@ -224,14 +229,19 @@ def get_container_metadata(decoder: torch.Tensor) -> ContainerMetadata:
224229
common_meta = dict(
225230
duration_seconds_from_header=stream_dict.get("durationSeconds"),
226231
bit_rate=stream_dict.get("bitRate"),
227-
begin_stream_seconds_from_content=stream_dict.get("minPtsSecondsFromScan"),
228-
end_stream_seconds_from_content=stream_dict.get("maxPtsSecondsFromScan"),
232+
begin_stream_seconds_from_header=stream_dict.get("beginStreamFromHeader"),
229233
codec=stream_dict.get("codec"),
230234
stream_index=stream_index,
231235
)
232236
if stream_dict["mediaType"] == "video":
233237
streams_metadata.append(
234238
VideoStreamMetadata(
239+
begin_stream_seconds_from_content=stream_dict.get(
240+
"minPtsSecondsFromScan"
241+
),
242+
end_stream_seconds_from_content=stream_dict.get(
243+
"maxPtsSecondsFromScan"
244+
),
235245
width=stream_dict.get("width"),
236246
height=stream_dict.get("height"),
237247
num_frames_from_header=stream_dict.get("numFrames"),

src/torchcodec/decoders/_decoder_utils.py

Lines changed: 1 addition & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from pathlib import Path
88

9-
from typing import Optional, Tuple, Union
9+
from typing import Union
1010

1111
from torch import Tensor
1212
from torchcodec.decoders import _core as core
@@ -33,55 +33,3 @@ def create_decoder(
3333
f"Unknown source type: {type(source)}. "
3434
"Supported types are str, Path, bytes and Tensor."
3535
)
36-
37-
38-
def get_and_validate_stream_metadata(
39-
*,
40-
decoder: Tensor,
41-
stream_index: Optional[int] = None,
42-
media_type: str,
43-
) -> Tuple[core._metadata.StreamMetadata, int, float, float]:
44-
45-
if media_type not in ("video", "audio"):
46-
raise ValueError(f"Bad {media_type = }, should be audio or video")
47-
48-
container_metadata = core.get_container_metadata(decoder)
49-
50-
if stream_index is None:
51-
best_stream_index = (
52-
container_metadata.best_video_stream_index
53-
if media_type == "video"
54-
else container_metadata.best_audio_stream_index
55-
)
56-
if best_stream_index is None:
57-
raise ValueError(
58-
f"The best {media_type} stream is unknown and there is no specified stream. "
59-
+ ERROR_REPORTING_INSTRUCTIONS
60-
)
61-
stream_index = best_stream_index
62-
63-
# This should be logically true because of the above conditions, but type checker
64-
# is not clever enough.
65-
assert stream_index is not None
66-
67-
metadata = container_metadata.streams[stream_index]
68-
69-
if metadata.begin_stream_seconds is None:
70-
raise ValueError(
71-
"The minimum pts value in seconds is unknown. "
72-
+ ERROR_REPORTING_INSTRUCTIONS
73-
)
74-
begin_stream_seconds = metadata.begin_stream_seconds
75-
76-
if metadata.end_stream_seconds is None:
77-
raise ValueError(
78-
"The maximum pts value in seconds is unknown. "
79-
+ ERROR_REPORTING_INSTRUCTIONS
80-
)
81-
end_stream_seconds = metadata.end_stream_seconds
82-
return (
83-
metadata,
84-
stream_index,
85-
begin_stream_seconds,
86-
end_stream_seconds,
87-
)

src/torchcodec/decoders/_video_decoder.py

Lines changed: 51 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import numbers
88
from pathlib import Path
9-
from typing import Literal, Optional, Union
9+
from typing import Literal, Optional, Tuple, Union
1010

1111
from torch import device, Tensor
1212

@@ -15,7 +15,6 @@
1515
from torchcodec.decoders._decoder_utils import (
1616
create_decoder,
1717
ERROR_REPORTING_INSTRUCTIONS,
18-
get_and_validate_stream_metadata,
1918
)
2019

2120

@@ -108,18 +107,11 @@ def __init__(
108107
self.stream_index,
109108
self._begin_stream_seconds,
110109
self._end_stream_seconds,
111-
) = get_and_validate_stream_metadata(
112-
decoder=self._decoder, stream_index=stream_index, media_type="video"
110+
self._num_frames,
111+
) = _get_and_validate_stream_metadata(
112+
decoder=self._decoder, stream_index=stream_index
113113
)
114114

115-
assert isinstance(self.metadata, core.VideoStreamMetadata) # mypy
116-
117-
if self.metadata.num_frames is None:
118-
raise ValueError(
119-
"The number of frames is unknown. " + ERROR_REPORTING_INSTRUCTIONS
120-
)
121-
self._num_frames = self.metadata.num_frames
122-
123115
def __len__(self) -> int:
124116
return self._num_frames
125117

@@ -338,3 +330,50 @@ def get_frames_played_in_range(
338330
stop_seconds=stop_seconds,
339331
)
340332
return FrameBatch(*frames)
333+
334+
335+
def _get_and_validate_stream_metadata(
336+
*,
337+
decoder: Tensor,
338+
stream_index: Optional[int] = None,
339+
) -> Tuple[core._metadata.VideoStreamMetadata, int, float, float, int]:
340+
341+
container_metadata = core.get_container_metadata(decoder)
342+
343+
if stream_index is None:
344+
if (stream_index := container_metadata.best_video_stream_index) is None:
345+
raise ValueError(
346+
"The best video stream is unknown and there is no specified stream. "
347+
+ ERROR_REPORTING_INSTRUCTIONS
348+
)
349+
350+
metadata = container_metadata.streams[stream_index]
351+
assert isinstance(metadata, core._metadata.VideoStreamMetadata) # mypy
352+
353+
if metadata.begin_stream_seconds is None:
354+
raise ValueError(
355+
"The minimum pts value in seconds is unknown. "
356+
+ ERROR_REPORTING_INSTRUCTIONS
357+
)
358+
begin_stream_seconds = metadata.begin_stream_seconds
359+
360+
if metadata.end_stream_seconds is None:
361+
raise ValueError(
362+
"The maximum pts value in seconds is unknown. "
363+
+ ERROR_REPORTING_INSTRUCTIONS
364+
)
365+
end_stream_seconds = metadata.end_stream_seconds
366+
367+
if metadata.num_frames is None:
368+
raise ValueError(
369+
"The number of frames is unknown. " + ERROR_REPORTING_INSTRUCTIONS
370+
)
371+
num_frames = metadata.num_frames
372+
373+
return (
374+
metadata,
375+
stream_index,
376+
begin_stream_seconds,
377+
end_stream_seconds,
378+
num_frames,
379+
)

0 commit comments

Comments
 (0)