Skip to content

Commit 5112494

Browse files
committed
Looots of stuff
1 parent f5ff8f0 commit 5112494

File tree

6 files changed

+188
-307
lines changed

6 files changed

+188
-307
lines changed

src/torchcodec/decoders/_audio_decoder.py

Lines changed: 17 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -5,136 +5,39 @@
55
# LICENSE file in the root directory of this source tree.
66

77
from pathlib import Path
8-
from typing import Literal, Optional, Tuple, Union
8+
from typing import Literal, Optional, Union
99

1010
from torch import Tensor
1111

1212
from torchcodec.decoders import _core as core
13-
14-
_ERROR_REPORTING_INSTRUCTIONS = """
15-
This should never happen. Please report an issue following the steps in
16-
https://github.com/pytorch/torchcodec/issues/new?assignees=&labels=&projects=&template=bug-report.yml.
17-
"""
13+
from torchcodec.decoders._decoder_utils import (
14+
create_decoder,
15+
get_and_validate_stream_metadata,
16+
validate_seek_mode,
17+
)
1818

1919

2020
class AudioDecoder:
21-
"""A single-stream audio decoder.
22-
23-
TODO docs
24-
"""
21+
"""TODO-audio docs"""
2522

2623
def __init__(
2724
self,
2825
source: Union[str, Path, bytes, Tensor],
2926
*,
30-
sample_rate: Optional[int] = None,
3127
stream_index: Optional[int] = None,
3228
seek_mode: Literal["exact", "approximate"] = "exact",
3329
):
34-
if sample_rate is not None:
35-
raise ValueError("TODO implement this")
36-
37-
# TODO unify validation with VideoDecoder?
38-
allowed_seek_modes = ("exact", "approximate")
39-
if seek_mode not in allowed_seek_modes:
40-
raise ValueError(
41-
f"Invalid seek mode ({seek_mode}). "
42-
f"Supported values are {', '.join(allowed_seek_modes)}."
43-
)
44-
45-
if isinstance(source, str):
46-
self._decoder = core.create_from_file(source, seek_mode)
47-
elif isinstance(source, Path):
48-
self._decoder = core.create_from_file(str(source), seek_mode)
49-
elif isinstance(source, bytes):
50-
self._decoder = core.create_from_bytes(source, seek_mode)
51-
elif isinstance(source, Tensor):
52-
self._decoder = core.create_from_tensor(source, seek_mode)
53-
else:
54-
raise TypeError(
55-
f"Unknown source type: {type(source)}. "
56-
"Supported types are str, Path, bytes and Tensor."
57-
)
30+
validate_seek_mode(seek_mode)
31+
self._decoder = create_decoder(source=source, seek_mode=seek_mode)
5832

5933
core.add_audio_stream(self._decoder, stream_index=stream_index)
6034

61-
self.metadata, self.stream_index = _get_and_validate_stream_metadata(
62-
self._decoder, stream_index
63-
)
64-
65-
# if self.metadata.num_frames is None:
66-
# raise ValueError(
67-
# "The number of frames is unknown. " + _ERROR_REPORTING_INSTRUCTIONS
68-
# )
69-
# self._num_frames = self.metadata.num_frames
70-
71-
# if self.metadata.begin_stream_seconds is None:
72-
# raise ValueError(
73-
# "The minimum pts value in seconds is unknown. "
74-
# + _ERROR_REPORTING_INSTRUCTIONS
75-
# )
76-
# self._begin_stream_seconds = self.metadata.begin_stream_seconds
77-
78-
# if self.metadata.end_stream_seconds is None:
79-
# raise ValueError(
80-
# "The maximum pts value in seconds is unknown. "
81-
# + _ERROR_REPORTING_INSTRUCTIONS
82-
# )
83-
# self._end_stream_seconds = self.metadata.end_stream_seconds
84-
85-
# TODO we need to have a default for stop_seconds.
86-
def get_samples_played_in_range(
87-
self, start_seconds: float, stop_seconds: float
88-
) -> Tensor:
89-
"""
90-
TODO DOCS
91-
"""
92-
# if not start_seconds <= stop_seconds:
93-
# raise ValueError(
94-
# f"Invalid start seconds: {start_seconds}. It must be less than or equal to stop seconds ({stop_seconds})."
95-
# )
96-
# if not self._begin_stream_seconds <= start_seconds < self._end_stream_seconds:
97-
# raise ValueError(
98-
# f"Invalid start seconds: {start_seconds}. "
99-
# f"It must be greater than or equal to {self._begin_stream_seconds} "
100-
# f"and less than or equal to {self._end_stream_seconds}."
101-
# )
102-
# if not stop_seconds <= self._end_stream_seconds:
103-
# raise ValueError(
104-
# f"Invalid stop seconds: {stop_seconds}. "
105-
# f"It must be less than or equal to {self._end_stream_seconds}."
106-
# )
107-
108-
frames, *_ = core.get_frames_by_pts_in_range(
109-
self._decoder,
110-
start_seconds=start_seconds,
111-
stop_seconds=stop_seconds,
35+
(
36+
self.metadata,
37+
self.stream_index,
38+
self._num_frames,
39+
self._begin_stream_seconds,
40+
self._end_stream_seconds,
41+
) = get_and_validate_stream_metadata(
42+
decoder=self._decoder, stream_index=stream_index, media_type="audio"
11243
)
113-
# TODO need to return view on this to account for samples instead of
114-
# frames
115-
return frames
116-
117-
118-
def _get_and_validate_stream_metadata(
119-
decoder: Tensor,
120-
stream_index: Optional[int] = None,
121-
) -> Tuple[core.AudioStreamMetadata, int]:
122-
123-
# TODO should this still be called `get_video_metadata`?
124-
container_metadata = core.get_video_metadata(decoder)
125-
126-
if stream_index is None:
127-
best_stream_index = container_metadata.best_audio_stream_index
128-
if best_stream_index is None:
129-
raise ValueError(
130-
"The best audio stream is unknown and there is no specified stream. "
131-
+ _ERROR_REPORTING_INSTRUCTIONS
132-
)
133-
stream_index = best_stream_index
134-
135-
# This should be logically true because of the above conditions, but type checker
136-
# is not clever enough.
137-
assert stream_index is not None
138-
139-
stream_metadata = container_metadata.streams[stream_index]
140-
return (stream_metadata, stream_index)

src/torchcodec/decoders/_core/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77

88
from ._metadata import (
99
AudioStreamMetadata,
10-
get_video_metadata,
11-
get_video_metadata_from_header,
12-
VideoMetadata,
10+
ContainerMetadata,
11+
get_container_metadata,
12+
get_container_metadata_from_header,
1313
VideoStreamMetadata,
1414
)
1515
from .video_decoder_ops import (

src/torchcodec/decoders/_core/_metadata.py

Lines changed: 42 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,9 @@
1919
)
2020

2121

22+
# TODO-audio: docs below are mostly for video streams.
2223
@dataclass
23-
class VideoStreamMetadata:
24-
"""Metadata of a single video stream."""
25-
24+
class StreamMetadata:
2625
duration_seconds_from_header: Optional[float]
2726
"""Duration of the stream, in seconds, obtained from the header (float or
2827
None). This could be inaccurate."""
@@ -54,10 +53,6 @@ class VideoStreamMetadata:
5453
"""
5554
codec: Optional[str]
5655
"""Codec (str or None)."""
57-
width: Optional[int]
58-
"""Width of the frames (int or None)."""
59-
height: Optional[int]
60-
"""Height of the frames (int or None)."""
6156
average_fps_from_header: Optional[float]
6257
"""Average fps of the stream, obtained from the header (float or None).
6358
We recommend using the ``average_fps`` attribute instead."""
@@ -145,109 +140,37 @@ def __repr__(self):
145140

146141

147142
@dataclass
148-
class AudioStreamMetadata:
149-
# TODO do we expose the notion of frame here, like in fps? It's technically
150-
# valid, but potentially is an FFmpeg-specific concept for audio
151-
# TODO Need sample rate and format
152-
sample_rate: Optional[int]
153-
duration_seconds_from_header: Optional[float]
154-
bit_rate: Optional[float]
155-
num_frames_from_header: Optional[int]
156-
num_frames_from_content: Optional[int]
157-
begin_stream_seconds_from_content: Optional[float]
158-
end_stream_seconds_from_content: Optional[float]
159-
codec: Optional[str]
160-
average_fps_from_header: Optional[float]
161-
stream_index: int
162-
163-
@property
164-
def num_frames(self) -> Optional[int]:
165-
"""Number of frames in the stream. This corresponds to
166-
``num_frames_from_content`` if a :term:`scan` was made, otherwise it
167-
corresponds to ``num_frames_from_header``.
168-
"""
169-
if self.num_frames_from_content is not None:
170-
return self.num_frames_from_content
171-
else:
172-
return self.num_frames_from_header
143+
class VideoStreamMetadata(StreamMetadata):
144+
"""Metadata of a single video stream."""
173145

174-
@property
175-
def duration_seconds(self) -> Optional[float]:
176-
"""Duration of the stream in seconds. We try to calculate the duration
177-
from the actual frames if a :term:`scan` was performed. Otherwise we
178-
fall back to ``duration_seconds_from_header``.
179-
"""
180-
if (
181-
self.end_stream_seconds_from_content is None
182-
or self.begin_stream_seconds_from_content is None
183-
):
184-
return self.duration_seconds_from_header
185-
return (
186-
self.end_stream_seconds_from_content
187-
- self.begin_stream_seconds_from_content
188-
)
146+
width: Optional[int]
147+
"""Width of the frames (int or None)."""
148+
height: Optional[int]
149+
"""Height of the frames (int or None)."""
189150

190-
@property
191-
def average_fps(self) -> Optional[float]:
192-
"""Average fps of the stream. If a :term:`scan` was performed, this is
193-
computed from the number of frames and the duration of the stream.
194-
Otherwise we fall back to ``average_fps_from_header``.
195-
"""
196-
if (
197-
self.end_stream_seconds_from_content is None
198-
or self.begin_stream_seconds_from_content is None
199-
or self.num_frames is None
200-
):
201-
return self.average_fps_from_header
202-
return self.num_frames / (
203-
self.end_stream_seconds_from_content
204-
- self.begin_stream_seconds_from_content
205-
)
151+
def __repr__(self):
152+
return super().__repr__()
206153

207-
@property
208-
def begin_stream_seconds(self) -> float:
209-
"""Beginning of the stream, in seconds (float). Conceptually, this
210-
corresponds to the first frame's :term:`pts`. If
211-
``begin_stream_seconds_from_content`` is not None, then it is returned.
212-
Otherwise, this value is 0.
213-
"""
214-
if self.begin_stream_seconds_from_content is None:
215-
return 0
216-
else:
217-
return self.begin_stream_seconds_from_content
218154

219-
@property
220-
def end_stream_seconds(self) -> Optional[float]:
221-
"""End of the stream, in seconds (float or None).
222-
Conceptually, this corresponds to last_frame.pts + last_frame.duration.
223-
If ``end_stream_seconds_from_content`` is not None, then that value is
224-
returned. Otherwise, returns ``duration_seconds``.
225-
"""
226-
if self.end_stream_seconds_from_content is None:
227-
return self.duration_seconds
228-
else:
229-
return self.end_stream_seconds_from_content
155+
@dataclass
156+
class AudioStreamMetadata(StreamMetadata):
157+
# TODO-AUDIO do we expose the notion of frame here, like in fps? It's technically
158+
# valid, but potentially is an FFmpeg-specific concept for audio
159+
# TODO-AUDIO Need sample rate and format and num_channels
160+
sample_rate: Optional[int]
230161

231162
def __repr__(self):
232-
# Overridden because properties are not printed by default.
233-
s = self.__class__.__name__ + ":\n"
234-
spaces = " "
235-
s += f"{spaces}num_frames: {self.num_frames}\n"
236-
s += f"{spaces}duration_seconds: {self.duration_seconds}\n"
237-
s += f"{spaces}average_fps: {self.average_fps}\n"
238-
for field in dataclasses.fields(self):
239-
s += f"{spaces}{field.name}: {getattr(self, field.name)}\n"
240-
return s
163+
return super().__repr__()
241164

242165

243166
@dataclass
244-
class VideoMetadata:
167+
class ContainerMetadata:
245168
duration_seconds_from_header: Optional[float]
246169
bit_rate_from_header: Optional[float]
247170
best_video_stream_index: Optional[int]
248171
best_audio_stream_index: Optional[int]
249172

250-
streams: List[Union[VideoStreamMetadata, AudioStreamMetadata]]
173+
streams: List[StreamMetadata]
251174

252175
@property
253176
def duration_seconds(self) -> Optional[float]:
@@ -266,15 +189,15 @@ def best_video_stream(self) -> VideoStreamMetadata:
266189
return metadata
267190

268191

269-
def get_video_metadata(decoder: torch.Tensor) -> VideoMetadata:
270-
"""Return video metadata from a video decoder.
192+
def get_container_metadata(decoder: torch.Tensor) -> ContainerMetadata:
193+
"""Return container metadata from a decoder.
271194
272195
The accuracy of the metadata and the availability of some returned fields
273196
depend on whether a full scan was performed by the decoder.
274197
"""
275198

276199
container_dict = json.loads(_get_container_json_metadata(decoder))
277-
streams_metadata: List[Union[VideoStreamMetadata, AudioStreamMetadata]] = []
200+
streams_metadata: List[StreamMetadata] = []
278201
for stream_index in range(container_dict["numStreams"]):
279202
stream_dict = json.loads(_get_stream_json_metadata(decoder, stream_index))
280203
common_meta = dict(
@@ -288,25 +211,29 @@ def get_video_metadata(decoder: torch.Tensor) -> VideoMetadata:
288211
average_fps_from_header=stream_dict.get("averageFps"),
289212
stream_index=stream_index,
290213
)
291-
if stream_dict["mediaType"] == "audio":
214+
if stream_dict["mediaType"] == "video":
292215
streams_metadata.append(
293-
AudioStreamMetadata(
294-
sample_rate=stream_dict.get("sampleRate"),
216+
VideoStreamMetadata(
217+
width=stream_dict.get("width"),
218+
height=stream_dict.get("height"),
295219
**common_meta,
296220
)
297221
)
298-
else:
299-
# TODO we're adding a VideoStreamMetadata for all non-audio streams,
300-
# including streams like subtitles, which makes little sense.
222+
elif stream_dict["mediaType"] == "audio":
301223
streams_metadata.append(
302-
VideoStreamMetadata(
303-
width=stream_dict.get("width"),
304-
height=stream_dict.get("height"),
224+
AudioStreamMetadata(
225+
sample_rate=stream_dict.get("sampleRate"),
305226
**common_meta,
306227
)
307228
)
229+
else:
230+
# This is neither a video nor audio stream. Could be e.g. subtitles.
231+
# We still need to add an entry to streams_metadata to keep its
232+
# length consistent with the number of streams, so we add a dummy
233+
# entry.
234+
streams_metadata.append(StreamMetadata(**common_meta))
308235

309-
return VideoMetadata(
236+
return ContainerMetadata(
310237
duration_seconds_from_header=container_dict.get("durationSeconds"),
311238
bit_rate_from_header=container_dict.get("bitRate"),
312239
best_video_stream_index=container_dict.get("bestVideoStreamIndex"),
@@ -315,5 +242,9 @@ def get_video_metadata(decoder: torch.Tensor) -> VideoMetadata:
315242
)
316243

317244

318-
def get_video_metadata_from_header(filename: Union[str, pathlib.Path]) -> VideoMetadata:
319-
return get_video_metadata(create_from_file(str(filename), seek_mode="approximate"))
245+
def get_container_metadata_from_header(
246+
filename: Union[str, pathlib.Path]
247+
) -> ContainerMetadata:
248+
return get_container_metadata(
249+
create_from_file(str(filename), seek_mode="approximate")
250+
)

0 commit comments

Comments
 (0)