Skip to content

Commit d841909

Browse files
authored
Add stream_index as an option to VideoDecoder (#254)
1 parent 5a674cf commit d841909

File tree

41 files changed

+295
-113
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+295
-113
lines changed

src/torchcodec/decoders/_video_decoder.py

Lines changed: 35 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import numbers
88
from pathlib import Path
9-
from typing import Literal, Tuple, Union
9+
from typing import Literal, Optional, Tuple, Union
1010

1111
from torch import Tensor
1212

@@ -22,14 +22,16 @@
2222
class VideoDecoder:
2323
"""A single-stream video decoder.
2424
25-
If the video contains multiple video streams, the :term:`best stream` is
26-
used. This decoder always performs a :term:`scan` of the video.
25+
This decoder always performs a :term:`scan` of the video.
2726
2827
Args:
2928
source (str, ``Pathlib.path``, ``torch.Tensor``, or bytes): The source of the video.
3029
3130
- If ``str`` or ``Pathlib.path``: a path to a local video file.
3231
- If ``bytes`` object or ``torch.Tensor``: the raw encoded video data.
32+
stream_index (int, optional): Specifies which stream in the video to decode frames from.
33+
Note that this index is absolute across all media types. If left unspecified, then
34+
the :term:`best stream` is used.
3335
dimension_order(str, optional): The dimension order of the decoded frames.
3436
This can be either "NCHW" (default) or "NHWC", where N is the batch
3537
size, C is the number of channels, H is the height, and W is the
@@ -45,11 +47,16 @@ class VideoDecoder:
4547
4648
Attributes:
4749
metadata (VideoStreamMetadata): Metadata of the video stream.
50+
stream_index (int): The stream index that this decoder is retrieving frames from. If a
51+
stream index was provided at initialization, this is the same value. If it was left
52+
unspecified, this is the :term:`best stream`.
4853
"""
4954

5055
def __init__(
5156
self,
5257
source: Union[str, Path, bytes, Tensor],
58+
*,
59+
stream_index: Optional[int] = None,
5360
dimension_order: Literal["NCHW", "NHWC"] = "NCHW",
5461
):
5562
if isinstance(source, str):
@@ -74,10 +81,12 @@ def __init__(
7481
)
7582

7683
core.scan_all_streams_to_update_metadata(self._decoder)
77-
core.add_video_stream(self._decoder, dimension_order=dimension_order)
84+
core.add_video_stream(
85+
self._decoder, stream_index=stream_index, dimension_order=dimension_order
86+
)
7887

79-
self.metadata, self._stream_index = _get_and_validate_stream_metadata(
80-
self._decoder
88+
self.metadata, self.stream_index = _get_and_validate_stream_metadata(
89+
self._decoder, stream_index
8190
)
8291

8392
if self.metadata.num_frames_from_content is None:
@@ -114,7 +123,7 @@ def _getitem_int(self, key: int) -> Tensor:
114123
)
115124

116125
frame_data, *_ = core.get_frame_at_index(
117-
self._decoder, frame_index=key, stream_index=self._stream_index
126+
self._decoder, frame_index=key, stream_index=self.stream_index
118127
)
119128
return frame_data
120129

@@ -124,7 +133,7 @@ def _getitem_slice(self, key: slice) -> Tensor:
124133
start, stop, step = key.indices(len(self))
125134
frame_data, *_ = core.get_frames_in_range(
126135
self._decoder,
127-
stream_index=self._stream_index,
136+
stream_index=self.stream_index,
128137
start=start,
129138
stop=stop,
130139
step=step,
@@ -164,7 +173,7 @@ def get_frame_at(self, index: int) -> Frame:
164173
f"Index {index} is out of bounds; must be in the range [0, {self._num_frames})."
165174
)
166175
data, pts_seconds, duration_seconds = core.get_frame_at_index(
167-
self._decoder, frame_index=index, stream_index=self._stream_index
176+
self._decoder, frame_index=index, stream_index=self.stream_index
168177
)
169178
return Frame(
170179
data=data,
@@ -198,7 +207,7 @@ def get_frames_at(self, start: int, stop: int, step: int = 1) -> FrameBatch:
198207
raise IndexError(f"Step ({step}) must be greater than 0.")
199208
frames = core.get_frames_in_range(
200209
self._decoder,
201-
stream_index=self._stream_index,
210+
stream_index=self.stream_index,
202211
start=start,
203212
stop=stop,
204213
step=step,
@@ -264,7 +273,7 @@ def get_frames_displayed_at(
264273
)
265274
frames = core.get_frames_by_pts_in_range(
266275
self._decoder,
267-
stream_index=self._stream_index,
276+
stream_index=self.stream_index,
268277
start_seconds=start_seconds,
269278
stop_seconds=stop_seconds,
270279
)
@@ -273,14 +282,22 @@ def get_frames_displayed_at(
273282

274283
def _get_and_validate_stream_metadata(
275284
decoder: Tensor,
285+
stream_index: Optional[int] = None,
276286
) -> Tuple[core.VideoStreamMetadata, int]:
277287
video_metadata = core.get_video_metadata(decoder)
278288

279-
best_stream_index = video_metadata.best_video_stream_index
280-
if best_stream_index is None:
281-
raise ValueError(
282-
"The best video stream is unknown. " + _ERROR_REPORTING_INSTRUCTIONS
283-
)
289+
if stream_index is None:
290+
best_stream_index = video_metadata.best_video_stream_index
291+
if best_stream_index is None:
292+
raise ValueError(
293+
"The best video stream is unknown and there is no specified stream. "
294+
+ _ERROR_REPORTING_INSTRUCTIONS
295+
)
296+
stream_index = best_stream_index
297+
298+
# This should be logically true because of the above conditions, but type checker
299+
# is not clever enough.
300+
assert stream_index is not None
284301

285-
best_stream_metadata = video_metadata.streams[best_stream_index]
286-
return (best_stream_metadata, best_stream_index)
302+
stream_metadata = video_metadata.streams[stream_index]
303+
return (stream_metadata, stream_index)

test/decoders/VideoDecoderTest.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -180,9 +180,9 @@ TEST_P(VideoDecoderTest, ReturnsFirstTwoFramesOfVideo) {
180180
EXPECT_EQ(output.pts, 1001);
181181

182182
torch::Tensor tensor0FromFFMPEG =
183-
readTensorFromDisk("nasa_13013.mp4.frame000000.pt");
183+
readTensorFromDisk("nasa_13013.mp4.stream3.frame000000.pt");
184184
torch::Tensor tensor1FromFFMPEG =
185-
readTensorFromDisk("nasa_13013.mp4.frame000001.pt");
185+
readTensorFromDisk("nasa_13013.mp4.stream3.frame000001.pt");
186186

187187
EXPECT_EQ(tensor1FromFFMPEG.sizes(), std::vector<long>({3, 270, 480}));
188188
EXPECT_TRUE(torch::equal(tensor0FromOurDecoder, tensor0FromFFMPEG));
@@ -215,7 +215,7 @@ TEST_P(VideoDecoderTest, DecodesFramesInABatchInNCHW) {
215215
EXPECT_EQ(tensor.sizes(), std::vector<long>({2, 3, 270, 480}));
216216

217217
torch::Tensor tensor0FromFFMPEG =
218-
readTensorFromDisk("nasa_13013.mp4.frame000000.pt");
218+
readTensorFromDisk("nasa_13013.mp4.stream3.frame000000.pt");
219219
torch::Tensor tensorTime6FromFFMPEG =
220220
readTensorFromDisk("nasa_13013.mp4.time6.000000.pt");
221221

@@ -239,7 +239,7 @@ TEST_P(VideoDecoderTest, DecodesFramesInABatchInNHWC) {
239239
EXPECT_EQ(tensor.sizes(), std::vector<long>({2, 270, 480, 3}));
240240

241241
torch::Tensor tensor0FromFFMPEG =
242-
readTensorFromDisk("nasa_13013.mp4.frame000000.pt");
242+
readTensorFromDisk("nasa_13013.mp4.stream3.frame000000.pt");
243243
torch::Tensor tensorTime6FromFFMPEG =
244244
readTensorFromDisk("nasa_13013.mp4.time6.000000.pt");
245245

0 commit comments

Comments
 (0)