Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions src/torchcodec/decoders/_video_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from pathlib import Path
from typing import Literal, Optional, Tuple, Union

from torch import device, Tensor
from torch import device as torch_device, Tensor

from torchcodec import Frame, FrameBatch
from torchcodec.decoders import _core as core
Expand Down Expand Up @@ -72,7 +72,7 @@ def __init__(
stream_index: Optional[int] = None,
dimension_order: Literal["NCHW", "NHWC"] = "NCHW",
num_ffmpeg_threads: int = 1,
device: Optional[Union[str, device]] = "cpu",
device: Optional[Union[str, torch_device]] = "cpu",
seek_mode: Literal["exact", "approximate"] = "exact",
):
allowed_seek_modes = ("exact", "approximate")
Expand All @@ -94,6 +94,9 @@ def __init__(
if num_ffmpeg_threads is None:
raise ValueError(f"{num_ffmpeg_threads = } should be an int.")

if isinstance(device, torch_device):
device = str(device)

core.add_video_stream(
self._decoder,
stream_index=stream_index,
Expand Down
Loading