66
77import numbers
88from pathlib import Path
9- from typing import Literal , Tuple , Union
9+ from typing import Literal , Optional , Tuple , Union
1010
1111from torch import Tensor
1212
2222class VideoDecoder :
2323 """A single-stream video decoder.
2424
25- If the video contains multiple video streams, the :term:`best stream` is
26- used. This decoder always performs a :term:`scan` of the video.
25+ This decoder always performs a :term:`scan` of the video.
2726
2827 Args:
2928 source (str, ``pathlib.Path``, ``torch.Tensor``, or bytes): The source of the video.
3029
3130 - If ``str`` or ``pathlib.Path``: a path to a local video file.
3231 - If ``bytes`` object or ``torch.Tensor``: the raw encoded video data.
32+ stream_index (int, optional): Specifies which stream in the video to decode frames from.
33+ Note that this index is absolute across all media types. If left unspecified, then
34+ the :term:`best stream` is used.
3335 dimension_order(str, optional): The dimension order of the decoded frames.
3436 This can be either "NCHW" (default) or "NHWC", where N is the batch
3537 size, C is the number of channels, H is the height, and W is the
@@ -45,11 +47,16 @@ class VideoDecoder:
4547
4648 Attributes:
4749 metadata (VideoStreamMetadata): Metadata of the video stream.
50+ stream_index (int): The stream index that this decoder is retrieving frames from. If a
51+ stream index was provided at initialization, this is the same value. If it was left
52+ unspecified, this is the :term:`best stream`.
4853 """
4954
5055 def __init__ (
5156 self ,
5257 source : Union [str , Path , bytes , Tensor ],
58+ * ,
59+ stream_index : Optional [int ] = None ,
5360 dimension_order : Literal ["NCHW" , "NHWC" ] = "NCHW" ,
5461 ):
5562 if isinstance (source , str ):
@@ -74,10 +81,12 @@ def __init__(
7481 )
7582
7683 core .scan_all_streams_to_update_metadata (self ._decoder )
77- core .add_video_stream (self ._decoder , dimension_order = dimension_order )
84+ core .add_video_stream (
85+ self ._decoder , stream_index = stream_index , dimension_order = dimension_order
86+ )
7887
79- self .metadata , self ._stream_index = _get_and_validate_stream_metadata (
80- self ._decoder
88+ self .metadata , self .stream_index = _get_and_validate_stream_metadata (
89+ self ._decoder , stream_index
8190 )
8291
8392 if self .metadata .num_frames_from_content is None :
@@ -114,7 +123,7 @@ def _getitem_int(self, key: int) -> Tensor:
114123 )
115124
116125 frame_data , * _ = core .get_frame_at_index (
117- self ._decoder , frame_index = key , stream_index = self ._stream_index
126+ self ._decoder , frame_index = key , stream_index = self .stream_index
118127 )
119128 return frame_data
120129
@@ -124,7 +133,7 @@ def _getitem_slice(self, key: slice) -> Tensor:
124133 start , stop , step = key .indices (len (self ))
125134 frame_data , * _ = core .get_frames_in_range (
126135 self ._decoder ,
127- stream_index = self ._stream_index ,
136+ stream_index = self .stream_index ,
128137 start = start ,
129138 stop = stop ,
130139 step = step ,
@@ -164,7 +173,7 @@ def get_frame_at(self, index: int) -> Frame:
164173 f"Index { index } is out of bounds; must be in the range [0, { self ._num_frames } )."
165174 )
166175 data , pts_seconds , duration_seconds = core .get_frame_at_index (
167- self ._decoder , frame_index = index , stream_index = self ._stream_index
176+ self ._decoder , frame_index = index , stream_index = self .stream_index
168177 )
169178 return Frame (
170179 data = data ,
@@ -198,7 +207,7 @@ def get_frames_at(self, start: int, stop: int, step: int = 1) -> FrameBatch:
198207 raise IndexError (f"Step ({ step } ) must be greater than 0." )
199208 frames = core .get_frames_in_range (
200209 self ._decoder ,
201- stream_index = self ._stream_index ,
210+ stream_index = self .stream_index ,
202211 start = start ,
203212 stop = stop ,
204213 step = step ,
@@ -264,7 +273,7 @@ def get_frames_displayed_at(
264273 )
265274 frames = core .get_frames_by_pts_in_range (
266275 self ._decoder ,
267- stream_index = self ._stream_index ,
276+ stream_index = self .stream_index ,
268277 start_seconds = start_seconds ,
269278 stop_seconds = stop_seconds ,
270279 )
@@ -273,14 +282,22 @@ def get_frames_displayed_at(
273282
def _get_and_validate_stream_metadata(
    decoder: Tensor,
    stream_index: Optional[int] = None,
) -> Tuple[core.VideoStreamMetadata, int]:
    """Resolve which stream to use and return its metadata.

    Args:
        decoder: The decoder tensor handed to ``core.get_video_metadata``.
        stream_index: Index of the stream to look up. When ``None``, the
            ``best_video_stream_index`` reported by the video metadata is
            used instead.

    Returns:
        A ``(stream_metadata, stream_index)`` tuple, where ``stream_index``
        is the index that was actually used for the lookup.

    Raises:
        ValueError: If ``stream_index`` is ``None`` and the video metadata
            does not report a best video stream.
    """
    container_metadata = core.get_video_metadata(decoder)

    resolved_index = stream_index
    if resolved_index is None:
        # No explicit request: fall back to the best stream from the metadata.
        resolved_index = container_metadata.best_video_stream_index
        if resolved_index is None:
            raise ValueError(
                "The best video stream is unknown and there is no specified stream. "
                + _ERROR_REPORTING_INSTRUCTIONS
            )

    # Always true given the branches above; kept so that static type checkers
    # treat the index as a plain ``int`` from here on.
    assert resolved_index is not None

    return (container_metadata.streams[resolved_index], resolved_index)
0 commit comments