Skip to content

Commit d3bdfea

Browse files
committed
More stuff
1 parent 6383bca commit d3bdfea

File tree

10 files changed

+197
-145
lines changed

10 files changed

+197
-145
lines changed

src/torchcodec/decoders/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
from ._core import VideoStreamMetadata
7+
from ._core import AudioStreamMetadata, VideoStreamMetadata
88
from ._video_decoder import VideoDecoder # noqa
99

1010
SimpleVideoDecoder = VideoDecoder

src/torchcodec/decoders/_audio_decoder.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,14 @@
55
# LICENSE file in the root directory of this source tree.
66

77
from pathlib import Path
8-
from typing import Literal, Optional, Union
8+
from typing import Optional, Union
99

1010
from torch import Tensor
1111

1212
from torchcodec.decoders import _core as core
1313
from torchcodec.decoders._decoder_utils import (
1414
create_decoder,
1515
get_and_validate_stream_metadata,
16-
validate_seek_mode,
1716
)
1817

1918

@@ -25,17 +24,14 @@ def __init__(
2524
source: Union[str, Path, bytes, Tensor],
2625
*,
2726
stream_index: Optional[int] = None,
28-
seek_mode: Literal["exact", "approximate"] = "exact",
2927
):
30-
validate_seek_mode(seek_mode)
31-
self._decoder = create_decoder(source=source, seek_mode=seek_mode)
28+
self._decoder = create_decoder(source=source, seek_mode="approximate")
3229

3330
core.add_audio_stream(self._decoder, stream_index=stream_index)
3431

3532
(
3633
self.metadata,
3734
self.stream_index,
38-
self._num_frames,
3935
self._begin_stream_seconds,
4036
self._end_stream_seconds,
4137
) = get_and_validate_stream_metadata(

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 4 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -466,30 +466,6 @@ void VideoDecoder::addStream(
466466
.value_or(avCodec));
467467
}
468468

469-
// TODO_FRAME_SIZE_APPROXIMATE_MODE
470-
// For audio, we raise if seek_mode="approximate" and if the number of
471-
// samples per frame is unknown (frame_size field of codec params). But that's
472-
// quite limiting. Ultimately, the most common type of call will be to decode
473-
// an entire file from start to end (possibly with some offsets for start and
474-
// end). And for that, we shouldn't [need to] force the user to scan, because
475-
// all this entails is a single call to seek(start) (if at all) and then just
476-
// a bunch of consecutive calls to getNextFrame(). Maybe there should be a
477-
// third seek mode for audio, e.g. seek_mode="contiguous" where we don't scan,
478-
// and only allow calls to getFramesPlayedAt().
479-
StreamMetadata& streamMetadata =
480-
containerMetadata_.allStreamMetadata[activeStreamIndex_];
481-
if (seekMode_ == SeekMode::approximate &&
482-
!streamMetadata.averageFps.has_value()) {
483-
std::string errMsg = "Seek mode is approximate, but stream " +
484-
std::to_string(activeStreamIndex_) + "does not have ";
485-
if (mediaType == AVMEDIA_TYPE_VIDEO) {
486-
errMsg += "an average fps in its metadata.";
487-
} else {
488-
errMsg += "a constant number of samples per frame.";
489-
}
490-
throw std::runtime_error(errMsg);
491-
}
492-
493469
AVCodecContext* codecContext = avcodec_alloc_context3(avCodec);
494470
TORCH_CHECK(codecContext != nullptr);
495471
codecContext->thread_count =
@@ -565,13 +541,12 @@ void VideoDecoder::addVideoStream(
565541
}
566542

567543
void VideoDecoder::addAudioStream(int streamIndex) {
544+
TORCH_CHECK(
545+
seekMode_ == SeekMode::approximate,
546+
"seek_mode must be 'approximate' for audio streams.");
547+
568548
addStream(streamIndex, AVMEDIA_TYPE_AUDIO);
569549

570-
// See TODO_FRAME_SIZE_BATCH_TENSOR_ALLOCATION
571-
auto& streamInfo = streamInfos_[activeStreamIndex_];
572-
TORCH_CHECK(
573-
streamInfo.codecContext->frame_size > 0,
574-
"No support for variable framerate yet.");
575550
containerMetadata_.allStreamMetadata[activeStreamIndex_].sampleRate =
576551
streamInfo.codecContext->sample_rate;
577552
}

src/torchcodec/decoders/_core/_metadata.py

Lines changed: 62 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@
1919
)
2020

2121

22+
SPACES = " "
23+
24+
2225
# TODO-audio: docs below are mostly for video streams, we should edit them and /
2326
# or make sure they're OK for audio streams as well. Not sure how to best handle
2427
# docs for such class hierarchy.
@@ -29,15 +32,6 @@ class StreamMetadata:
2932
None). This could be inaccurate."""
3033
bit_rate: Optional[float]
3134
"""Bit rate of the stream, in bits per second (float or None)."""
32-
num_frames_from_header: Optional[int]
33-
"""Number of frames, from the stream's metadata. This is potentially
34-
inaccurate. We recommend using the ``num_frames`` attribute instead.
35-
(int or None)."""
36-
num_frames_from_content: Optional[int]
37-
"""Number of frames computed by TorchCodec by scanning the stream's
38-
content (the scan doesn't involve decoding). This is more accurate
39-
than ``num_frames_from_header``. We recommend using the
40-
``num_frames`` attribute instead. (int or None)."""
4135
begin_stream_seconds_from_content: Optional[float]
4236
"""Beginning of the stream, in seconds (float or None).
4337
Conceptually, this corresponds to the first frame's :term:`pts`. It is
@@ -55,23 +49,9 @@ class StreamMetadata:
5549
"""
5650
codec: Optional[str]
5751
"""Codec (str or None)."""
58-
average_fps_from_header: Optional[float]
59-
"""Average fps of the stream, obtained from the header (float or None).
60-
We recommend using the ``average_fps`` attribute instead."""
6152
stream_index: int
6253
"""Index of the stream within the video (int)."""
6354

64-
@property
65-
def num_frames(self) -> Optional[int]:
66-
"""Number of frames in the stream. This corresponds to
67-
``num_frames_from_content`` if a :term:`scan` was made, otherwise it
68-
corresponds to ``num_frames_from_header``.
69-
"""
70-
if self.num_frames_from_content is not None:
71-
return self.num_frames_from_content
72-
else:
73-
return self.num_frames_from_header
74-
7555
@property
7656
def duration_seconds(self) -> Optional[float]:
7757
"""Duration of the stream in seconds. We try to calculate the duration
@@ -88,23 +68,6 @@ def duration_seconds(self) -> Optional[float]:
8868
- self.begin_stream_seconds_from_content
8969
)
9070

91-
@property
92-
def average_fps(self) -> Optional[float]:
93-
"""Average fps of the stream. If a :term:`scan` was performed, this is
94-
computed from the number of frames and the duration of the stream.
95-
Otherwise we fall back to ``average_fps_from_header``.
96-
"""
97-
if (
98-
self.end_stream_seconds_from_content is None
99-
or self.begin_stream_seconds_from_content is None
100-
or self.num_frames is None
101-
):
102-
return self.average_fps_from_header
103-
return self.num_frames / (
104-
self.end_stream_seconds_from_content
105-
- self.begin_stream_seconds_from_content
106-
)
107-
10871
@property
10972
def begin_stream_seconds(self) -> float:
11073
"""Beginning of the stream, in seconds (float). Conceptually, this
@@ -132,12 +95,9 @@ def end_stream_seconds(self) -> Optional[float]:
13295
def __repr__(self):
13396
# Overridden because properties are not printed by default.
13497
s = self.__class__.__name__ + ":\n"
135-
spaces = " "
136-
s += f"{spaces}num_frames: {self.num_frames}\n"
137-
s += f"{spaces}duration_seconds: {self.duration_seconds}\n"
138-
s += f"{spaces}average_fps: {self.average_fps}\n"
98+
s += f"{SPACES}duration_seconds: {self.duration_seconds}\n"
13999
for field in dataclasses.fields(self):
140-
s += f"{spaces}{field.name}: {getattr(self, field.name)}\n"
100+
s += f"{SPACES}{field.name}: {getattr(self, field.name)}\n"
141101
return s
142102

143103

@@ -149,17 +109,58 @@ class VideoStreamMetadata(StreamMetadata):
149109
"""Width of the frames (int or None)."""
150110
height: Optional[int]
151111
"""Height of the frames (int or None)."""
112+
num_frames_from_header: Optional[int]
113+
"""Number of frames, from the stream's metadata. This is potentially
114+
inaccurate. We recommend using the ``num_frames`` attribute instead.
115+
(int or None)."""
116+
num_frames_from_content: Optional[int]
117+
"""Number of frames computed by TorchCodec by scanning the stream's
118+
content (the scan doesn't involve decoding). This is more accurate
119+
than ``num_frames_from_header``. We recommend using the
120+
``num_frames`` attribute instead. (int or None)."""
121+
average_fps_from_header: Optional[float]
122+
"""Average fps of the stream, obtained from the header (float or None).
123+
We recommend using the ``average_fps`` attribute instead."""
124+
125+
@property
126+
def num_frames(self) -> Optional[int]:
127+
"""Number of frames in the stream. This corresponds to
128+
``num_frames_from_content`` if a :term:`scan` was made, otherwise it
129+
corresponds to ``num_frames_from_header``.
130+
"""
131+
if self.num_frames_from_content is not None:
132+
return self.num_frames_from_content
133+
else:
134+
return self.num_frames_from_header
135+
136+
@property
137+
def average_fps(self) -> Optional[float]:
138+
"""Average fps of the stream. If a :term:`scan` was performed, this is
139+
computed from the number of frames and the duration of the stream.
140+
Otherwise we fall back to ``average_fps_from_header``.
141+
"""
142+
if (
143+
self.end_stream_seconds_from_content is None
144+
or self.begin_stream_seconds_from_content is None
145+
or self.num_frames is None
146+
):
147+
return self.average_fps_from_header
148+
return self.num_frames / (
149+
self.end_stream_seconds_from_content
150+
- self.begin_stream_seconds_from_content
151+
)
152152

153153
def __repr__(self):
154-
return super().__repr__()
154+
s = super().__repr__()
155+
s += f"{SPACES}num_frames: {self.num_frames}\n"
156+
s += f"{SPACES}average_fps: {self.average_fps}\n"
157+
return s
155158

156159

157160
@dataclass
158161
class AudioStreamMetadata(StreamMetadata):
159162
"""Metadata of a single audio stream."""
160163

161-
# TODO-AUDIO do we expose the notion of frame here, like in fps? It's technically
162-
# valid, but potentially is an FFmpeg-specific concept for audio
163164
# TODO-AUDIO Need sample rate and format and num_channels
164165
sample_rate: Optional[int]
165166

@@ -192,6 +193,14 @@ def best_video_stream(self) -> VideoStreamMetadata:
192193
assert isinstance(metadata, VideoStreamMetadata) # mypy <3
193194
return metadata
194195

196+
@property
197+
def best_audio_stream(self) -> AudioStreamMetadata:
198+
if self.best_audio_stream_index is None:
199+
raise ValueError("The best audio stream is unknown.")
200+
metadata = self.streams[self.best_audio_stream_index]
201+
assert isinstance(metadata, AudioStreamMetadata) # mypy <3
202+
return metadata
203+
195204

196205
def get_container_metadata(decoder: torch.Tensor) -> ContainerMetadata:
197206
"""Return container metadata from a decoder.
@@ -207,19 +216,19 @@ def get_container_metadata(decoder: torch.Tensor) -> ContainerMetadata:
207216
common_meta = dict(
208217
duration_seconds_from_header=stream_dict.get("durationSeconds"),
209218
bit_rate=stream_dict.get("bitRate"),
210-
num_frames_from_header=stream_dict.get("numFrames"),
211-
num_frames_from_content=stream_dict.get("numFramesFromScan"),
212219
begin_stream_seconds_from_content=stream_dict.get("minPtsSecondsFromScan"),
213220
end_stream_seconds_from_content=stream_dict.get("maxPtsSecondsFromScan"),
214221
codec=stream_dict.get("codec"),
215-
average_fps_from_header=stream_dict.get("averageFps"),
216222
stream_index=stream_index,
217223
)
218224
if stream_dict["mediaType"] == "video":
219225
streams_metadata.append(
220226
VideoStreamMetadata(
221227
width=stream_dict.get("width"),
222228
height=stream_dict.get("height"),
229+
num_frames_from_header=stream_dict.get("numFrames"),
230+
num_frames_from_content=stream_dict.get("numFramesFromScan"),
231+
average_fps_from_header=stream_dict.get("averageFps"),
223232
**common_meta,
224233
)
225234
)
@@ -232,9 +241,8 @@ def get_container_metadata(decoder: torch.Tensor) -> ContainerMetadata:
232241
)
233242
else:
234243
# This is neither a video nor audio stream. Could be e.g. subtitles.
235-
# We still need to add an entry to streams_metadata to keep its
236-
# length consistent with the number of streams, so we add a dummy
237-
# entry.
244+
# We still need to add a dummy entry so that len(streams_metadata)
245+
# is consistent with the number of streams.
238246
streams_metadata.append(StreamMetadata(**common_meta))
239247

240248
return ContainerMetadata(

src/torchcodec/decoders/_decoder_utils.py

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,6 @@
1717
"""
1818

1919

20-
def validate_seek_mode(seek_mode: str) -> None:
21-
allowed_seek_modes = ("exact", "approximate")
22-
if seek_mode not in allowed_seek_modes:
23-
raise ValueError(
24-
f"Invalid seek mode ({seek_mode}). "
25-
f"Supported values are {', '.join(allowed_seek_modes)}."
26-
)
27-
28-
2920
def create_decoder(
3021
*, source: Union[str, Path, bytes, Tensor], seek_mode: str
3122
) -> Tensor:
@@ -49,7 +40,7 @@ def get_and_validate_stream_metadata(
4940
decoder: Tensor,
5041
stream_index: Optional[int] = None,
5142
media_type: str,
52-
) -> Tuple[core.VideoStreamMetadata, int]:
43+
) -> Tuple[core.VideoStreamMetadata, int, float, float]:
5344

5445
if media_type not in ("video", "audio"):
5546
raise ValueError(f"Bad {media_type = }, should be audio or video")
@@ -75,12 +66,6 @@ def get_and_validate_stream_metadata(
7566

7667
metadata = container_metadata.streams[stream_index]
7768

78-
if metadata.num_frames is None:
79-
raise ValueError(
80-
"The number of frames is unknown. " + ERROR_REPORTING_INSTRUCTIONS
81-
)
82-
num_frames = metadata.num_frames
83-
8469
if metadata.begin_stream_seconds is None:
8570
raise ValueError(
8671
"The minimum pts value in seconds is unknown. "
@@ -97,7 +82,6 @@ def get_and_validate_stream_metadata(
9782
return (
9883
metadata,
9984
stream_index,
100-
num_frames,
10185
begin_stream_seconds,
10286
end_stream_seconds,
10387
)

src/torchcodec/decoders/_video_decoder.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414
from torchcodec.decoders import _core as core
1515
from torchcodec.decoders._decoder_utils import (
1616
create_decoder,
17+
ERROR_REPORTING_INSTRUCTIONS,
1718
get_and_validate_stream_metadata,
18-
validate_seek_mode,
1919
)
2020

2121

@@ -76,7 +76,13 @@ def __init__(
7676
device: Optional[Union[str, device]] = "cpu",
7777
seek_mode: Literal["exact", "approximate"] = "exact",
7878
):
79-
validate_seek_mode(seek_mode)
79+
allowed_seek_modes = ("exact", "approximate")
80+
if seek_mode not in allowed_seek_modes:
81+
raise ValueError(
82+
f"Invalid seek mode ({seek_mode}). "
83+
f"Supported values are {', '.join(allowed_seek_modes)}."
84+
)
85+
8086
self._decoder = create_decoder(source=source, seek_mode=seek_mode)
8187

8288
allowed_dimension_orders = ("NCHW", "NHWC")
@@ -100,13 +106,18 @@ def __init__(
100106
(
101107
self.metadata,
102108
self.stream_index,
103-
self._num_frames,
104109
self._begin_stream_seconds,
105110
self._end_stream_seconds,
106111
) = get_and_validate_stream_metadata(
107112
decoder=self._decoder, stream_index=stream_index, media_type="video"
108113
)
109114

115+
if self.metadata.num_frames is None:
116+
raise ValueError(
117+
"The number of frames is unknown. " + ERROR_REPORTING_INSTRUCTIONS
118+
)
119+
self._num_frames = self.metadata.num_frames
120+
110121
def __len__(self) -> int:
111122
return self._num_frames
112123

0 commit comments

Comments
 (0)