Skip to content

Commit 591995f

Browse files
committed
Support both RawIOBase and BytesIO
1 parent e9a726f commit 591995f

File tree

2 files changed

+10
-3
lines changed

2 files changed

+10
-3
lines changed

src/torchcodec/decoders/_core/ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ def create_from_bytes(
137137

138138

139139
def create_from_file_like(
140-
file_like: io.RawIOBase, seek_mode: Optional[str] = None
140+
file_like: io.RawIOBase | io.BytesIO, seek_mode: Optional[str] = None
141141
) -> torch.Tensor:
142142
assert _pybind_ops is not None
143143
return _convert_to_tensor(_pybind_ops.create_from_file_like(file_like, seek_mode))

test/decoders/test_ops.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,10 @@ def get_frame1_and_frame_time6(decoder):
343343
assert_frames_equal(frame_time6, reference_frame_time6.to(device))
344344

345345
@pytest.mark.parametrize("device", cpu_and_cuda())
346-
@pytest.mark.parametrize("create_from", ("file", "tensor", "bytes", "file_like"))
346+
@pytest.mark.parametrize(
347+
"create_from",
348+
("file", "tensor", "bytes", "file_like_rawio", "file_like_bufferedio"),
349+
)
347350
def test_create_decoder(self, create_from, device):
348351
path = str(NASA_VIDEO.path)
349352
if create_from == "file":
@@ -356,8 +359,12 @@ def test_create_decoder(self, create_from, device):
356359
with open(path, "rb") as f:
357360
video_bytes = f.read()
358361
decoder = create_from_bytes(video_bytes)
359-
elif create_from == "file_like":
362+
elif create_from == "file_like_rawio":
360363
decoder = create_from_file_like(open(path, mode="rb", buffering=0), "exact")
364+
elif create_from == "file_like_bufferedio":
365+
decoder = create_from_file_like(
366+
open(path, mode="rb", buffering=-4096), "exact"
367+
)
361368
else:
362369
raise ValueError("Oops, double check the parametrization of this test!")
363370

0 commit comments

Comments
 (0)