Skip to content

Commit d3b4e05

Browse files
committed
Add audio metadata
1 parent da9164e commit d3b4e05

File tree

13 files changed

+485
-73
lines changed

13 files changed

+485
-73
lines changed

src/torchcodec/decoders/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,6 @@
77
from ._core import VideoStreamMetadata
88
from ._video_decoder import VideoDecoder # noqa
99

10+
# from ._audio_decoder import AudioDecoder # Will be public when more stable
11+
1012
SimpleVideoDecoder = VideoDecoder
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from pathlib import Path
8+
from typing import Literal, Optional, Tuple, Union
9+
10+
from torch import Tensor
11+
12+
from torchcodec.decoders import _core as core
13+
14+
_ERROR_REPORTING_INSTRUCTIONS = """
15+
This should never happen. Please report an issue following the steps in
16+
https://github.com/pytorch/torchcodec/issues/new?assignees=&labels=&projects=&template=bug-report.yml.
17+
"""
18+
19+
20+
class AudioDecoder:
    """Decoder bound to a single audio stream of a media source.

    TODO docs
    """

    def __init__(
        self,
        source: Union[str, Path, bytes, Tensor],
        *,
        sample_rate: Optional[int] = None,
        stream_index: Optional[int] = None,
        seek_mode: Literal["exact", "approximate"] = "exact",
    ):
        """Create the underlying decoder and attach an audio stream to it.

        Args:
            source: Where to read the media from: a file path (``str`` or
                ``Path``), raw ``bytes``, or a ``Tensor``.
            sample_rate: Not supported yet; any value other than ``None``
                raises ``ValueError``.
            stream_index: Index of the stream to add. When ``None``, the
                container's best audio stream index is used for the metadata
                lookup.
            seek_mode: Either ``"exact"`` or ``"approximate"``; forwarded to
                the core decoder factory.

        Raises:
            ValueError: If ``sample_rate`` is given, if ``seek_mode`` is not a
                supported value, or if no stream index can be determined.
            TypeError: If ``source`` is not one of the supported types.
        """
        if sample_rate is not None:
            raise ValueError("TODO implement this")

        # TODO unify validation with VideoDecoder?
        allowed_seek_modes = ("exact", "approximate")
        if seek_mode not in allowed_seek_modes:
            raise ValueError(
                f"Invalid seek mode ({seek_mode}). "
                f"Supported values are {', '.join(allowed_seek_modes)}."
            )

        # Path objects are handled exactly like their string form.
        if isinstance(source, Path):
            source = str(source)

        if isinstance(source, str):
            handle = core.create_from_file(source, seek_mode)
        elif isinstance(source, bytes):
            handle = core.create_from_bytes(source, seek_mode)
        elif isinstance(source, Tensor):
            handle = core.create_from_tensor(source, seek_mode)
        else:
            raise TypeError(
                f"Unknown source type: {type(source)}. "
                "Supported types are str, Path, bytes and Tensor."
            )
        self._decoder = handle

        core.add_audio_stream(self._decoder, stream_index=stream_index)

        resolved = _get_and_validate_stream_metadata(self._decoder, stream_index)
        self.metadata, self.stream_index = resolved

        # Validation below is disabled for now (kept from the original WIP).
        # if self.metadata.num_frames is None:
        #     raise ValueError(
        #         "The number of frames is unknown. " + _ERROR_REPORTING_INSTRUCTIONS
        #     )
        # self._num_frames = self.metadata.num_frames

        # if self.metadata.begin_stream_seconds is None:
        #     raise ValueError(
        #         "The minimum pts value in seconds is unknown. "
        #         + _ERROR_REPORTING_INSTRUCTIONS
        #     )
        # self._begin_stream_seconds = self.metadata.begin_stream_seconds

        # if self.metadata.end_stream_seconds is None:
        #     raise ValueError(
        #         "The maximum pts value in seconds is unknown. "
        #         + _ERROR_REPORTING_INSTRUCTIONS
        #     )
        # self._end_stream_seconds = self.metadata.end_stream_seconds

    # TODO we need to have a default for stop_seconds.
    def get_samples_played_in_range(
        self, start_seconds: float, stop_seconds: float
    ) -> Tensor:
        """Return the decoded data for the given time range.

        TODO DOCS

        Args:
            start_seconds: Start of the range, forwarded to
                ``core.get_frames_by_pts_in_range``.
            stop_seconds: End of the range, forwarded to
                ``core.get_frames_by_pts_in_range``.

        Returns:
            The first element of what ``core.get_frames_by_pts_in_range``
            returns; every other element is discarded.
        """
        # Range validation is disabled for now (kept from the original WIP).
        # if not start_seconds <= stop_seconds:
        #     raise ValueError(
        #         f"Invalid start seconds: {start_seconds}. It must be less than or equal to stop seconds ({stop_seconds})."
        #     )
        # if not self._begin_stream_seconds <= start_seconds < self._end_stream_seconds:
        #     raise ValueError(
        #         f"Invalid start seconds: {start_seconds}. "
        #         f"It must be greater than or equal to {self._begin_stream_seconds} "
        #         f"and less than or equal to {self._end_stream_seconds}."
        #     )
        # if not stop_seconds <= self._end_stream_seconds:
        #     raise ValueError(
        #         f"Invalid stop seconds: {stop_seconds}. "
        #         f"It must be less than or equal to {self._end_stream_seconds}."
        #     )

        decoded = core.get_frames_by_pts_in_range(
            self._decoder,
            start_seconds=start_seconds,
            stop_seconds=stop_seconds,
        )
        samples, *_ = decoded
        # TODO need to return view on this to account for samples instead of
        # frames
        return samples
117+
118+
def _get_and_validate_stream_metadata(
    decoder: Tensor,
    stream_index: Optional[int] = None,
) -> Tuple[core.AudioStreamMetadata, int]:
    """Pick the stream to use and return its metadata along with its index.

    Args:
        decoder: Decoder handle whose container metadata is queried.
        stream_index: Stream to look up. When ``None``, the container's
            ``best_audio_stream_index`` is used instead.

    Returns:
        A ``(stream_metadata, stream_index)`` tuple for the selected stream.

    Raises:
        ValueError: If no ``stream_index`` was given and the container does
            not report a best audio stream.
    """
    # TODO should this still be called `get_video_metadata`?
    container_metadata = core.get_video_metadata(decoder)

    if stream_index is None:
        # Fall back to whatever the container reports as its best audio stream.
        stream_index = container_metadata.best_audio_stream_index
        if stream_index is None:
            raise ValueError(
                "The best audio stream is unknown and there is no specified stream. "
                + _ERROR_REPORTING_INSTRUCTIONS
            )

    # Always holds at this point; kept as an explicit hint for type checkers.
    assert stream_index is not None

    return (container_metadata.streams[stream_index], stream_index)

src/torchcodec/decoders/_core/CMakeLists.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@ set(CMAKE_CXX_STANDARD 17)
44
set(CMAKE_CXX_STANDARD_REQUIRED ON)
55

66
find_package(Torch REQUIRED)
7-
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -pedantic -Werror ${TORCH_CXX_FLAGS}")
7+
# TODO Put back normal flags
8+
# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -pedantic -Werror ${TORCH_CXX_FLAGS}")
9+
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall ${TORCH_CXX_FLAGS}")
810
find_package(Python3 ${PYTHON_VERSION} EXACT COMPONENTS Development)
911

1012
function(make_torchcodec_library library_name ffmpeg_target)

src/torchcodec/decoders/_core/FFMPEGCommon.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,26 @@ int64_t getDuration(const AVFrame* frame) {
6060
#endif
6161
}
6262

63+
int getNumChannels(const AVFrame* avFrame) {
64+
#if LIBAVFILTER_VERSION_MAJOR > 8 || \
65+
(IBAVFILTER_VERSION_MAJOR == 8 && LIBAVFILTER_VERSION_MINOR >= 44)
66+
return avFrame->ch_layout.nb_channels;
67+
#else
68+
return av_get_channel_layout_nb_channels(avFrame->channel_layout);
69+
#endif
70+
}
71+
72+
int getNumChannels(const UniqueAVCodecContext& avCodecContext) {
73+
// Not sure about the exactness of the version bounds, but as long as this
74+
// compile we're fine.
75+
#if LIBAVFILTER_VERSION_MAJOR > 8 || \
76+
(IBAVFILTER_VERSION_MAJOR == 8 && LIBAVFILTER_VERSION_MINOR >= 44)
77+
return avCodecContext->ch_layout.nb_channels;
78+
#else
79+
return avCodecContext->channels;
80+
#endif
81+
}
82+
6383
AVIOBytesContext::AVIOBytesContext(
6484
const void* data,
6585
size_t dataSize,

src/torchcodec/decoders/_core/FFMPEGCommon.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,9 @@ std::string getFFMPEGErrorStringFromErrorCode(int errorCode);
139139
int64_t getDuration(const UniqueAVFrame& frame);
140140
int64_t getDuration(const AVFrame* frame);
141141

142+
int getNumChannels(const AVFrame* avFrame);
143+
int getNumChannels(const UniqueAVCodecContext& avCodecContext);
144+
142145
// Returns true if sws_scale can handle unaligned data.
143146
bool canSwsScaleHandleUnalignedData();
144147

0 commit comments

Comments
 (0)