Skip to content

Commit 59af4b7

Browse files
Add option for the user to pass in ffmpeg thread count (#291)
Co-authored-by: Nicolas Hug <[email protected]>
1 parent fedfeba commit 59af4b7

File tree

2 files changed

+13
-3
lines changed

2 files changed

+13
-3
lines changed

src/torchcodec/decoders/_video_decoder.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,11 @@ class VideoDecoder:
3636
This can be either "NCHW" (default) or "NHWC", where N is the batch
3737
size, C is the number of channels, H is the height, and W is the
3838
width of the frames.
39+
num_ffmpeg_threads (int, optional): The number of threads to use for decoding.
40+
Use 1 for single-threaded decoding which may be best if you are running multiple
41+
instances of ``VideoDecoder`` in parallel. Use a higher number for multi-threaded
42+
decoding which is best if you are running a single instance of ``VideoDecoder``.
43+
Default: 1.
3944
4045
.. note::
4146
@@ -58,6 +63,7 @@ def __init__(
5863
*,
5964
stream_index: Optional[int] = None,
6065
dimension_order: Literal["NCHW", "NHWC"] = "NCHW",
66+
num_ffmpeg_threads: int = 1,
6167
):
6268
if isinstance(source, str):
6369
self._decoder = core.create_from_file(source)
@@ -82,7 +88,10 @@ def __init__(
8288

8389
core.scan_all_streams_to_update_metadata(self._decoder)
8490
core.add_video_stream(
85-
self._decoder, stream_index=stream_index, dimension_order=dimension_order
91+
self._decoder,
92+
stream_index=stream_index,
93+
dimension_order=dimension_order,
94+
num_threads=num_ffmpeg_threads,
8695
)
8796

8897
self.metadata, self.stream_index = _get_and_validate_stream_metadata(

test/decoders/test_video_decoder.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,9 @@ def test_create_fails(self):
5555
with pytest.raises(ValueError, match="No valid stream found"):
5656
decoder = VideoDecoder(NASA_VIDEO.path, stream_index=1) # noqa
5757

58-
def test_getitem_int(self):
59-
decoder = VideoDecoder(NASA_VIDEO.path)
58+
@pytest.mark.parametrize("num_ffmpeg_threads", (1, 4))
59+
def test_getitem_int(self, num_ffmpeg_threads):
60+
decoder = VideoDecoder(NASA_VIDEO.path, num_ffmpeg_threads=num_ffmpeg_threads)
6061

6162
ref_frame0 = NASA_VIDEO.get_frame_data_by_index(0)
6263
ref_frame1 = NASA_VIDEO.get_frame_data_by_index(1)

0 commit comments

Comments
 (0)