Skip to content

Commit 0e01169

Browse files
author
Daniel Flores
committed
fix tensor dtypes, add test
1 parent b3fe430 commit 0e01169

File tree

3 files changed

+34
-13
lines changed

3 files changed

+34
-13
lines changed

src/torchcodec/_core/SingleStreamDecoder.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,11 @@ void SingleStreamDecoder::scanFileAndUpdateMetadataAndIndex() {
322322
void SingleStreamDecoder::readCustomFrameMappingsUpdateMetadataAndIndex(
323323
int streamIndex,
324324
FrameMappings customFrameMappings) {
325+
TORCH_CHECK(
326+
customFrameMappings.all_frames.dtype() == torch::kLong &&
327+
customFrameMappings.is_key_frame.dtype() == torch::kBool &&
328+
customFrameMappings.duration.dtype() == torch::kLong,
329+
"all_frames and duration tensors must be int64 dtype, and is_key_frame tensor must be a bool dtype.");
325330
const torch::Tensor& all_frames =
326331
customFrameMappings.all_frames.to(torch::kLong);
327332
const torch::Tensor& is_key_frame =

src/torchcodec/decoders/_video_decoder.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -476,10 +476,12 @@ def _read_custom_frame_mappings(
476476
)
477477

478478
frame_data = [
479-
(float(frame[pts_key]), frame["key_frame"], float(frame[duration_key]))
479+
(int(frame[pts_key]), frame["key_frame"], int(frame[duration_key]))
480480
for frame in input_data["frames"]
481481
]
482-
all_frames, is_key_frame, duration = map(torch.tensor, zip(*frame_data))
482+
all_frames = torch.tensor([x[0] for x in frame_data], dtype=torch.int64)
483+
is_key_frame = torch.tensor([x[1] for x in frame_data], dtype=torch.bool)
484+
duration = torch.tensor([x[2] for x in frame_data], dtype=torch.int64)
483485
if not (len(all_frames) == len(is_key_frame) == len(duration)):
484486
raise ValueError("Mismatched lengths in frame index data")
485487
return all_frames, is_key_frame, duration

test/test_ops.py

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -481,27 +481,41 @@ def test_frame_pts_equality(self):
481481
assert pts_is_equal
482482

483483
def test_seek_mode_custom_frame_mappings_fails(self):
484-
decoder = create_from_file(
485-
str(NASA_VIDEO.path), seek_mode="custom_frame_mappings"
486-
)
487484
with pytest.raises(
488485
RuntimeError,
489486
match="Missing frame mappings when custom_frame_mappings seek mode is set.",
490487
):
488+
decoder = create_from_file(
489+
str(NASA_VIDEO.path), seek_mode="custom_frame_mappings"
490+
)
491491
add_video_stream(decoder, stream_index=0, custom_frame_mappings=None)
492492

493-
decoder = create_from_file(
494-
str(NASA_VIDEO.path), seek_mode="custom_frame_mappings"
495-
)
496-
different_lengths = (
497-
torch.tensor([1, 2, 3]),
498-
torch.tensor([1, 2]),
499-
torch.tensor([1, 2, 3]),
500-
)
493+
with pytest.raises(
494+
RuntimeError,
495+
match="all_frames and duration tensors must be int64 dtype, and is_key_frame tensor must be a bool dtype.",
496+
):
497+
decoder = create_from_file(
498+
str(NASA_VIDEO.path), seek_mode="custom_frame_mappings"
499+
)
500+
wrong_types = (
501+
torch.tensor([1.1, 2.2, 3.3]),
502+
torch.tensor([1, 2]),
503+
torch.tensor([1, 2, 3]),
504+
)
505+
add_video_stream(decoder, stream_index=0, custom_frame_mappings=wrong_types)
506+
501507
with pytest.raises(
502508
RuntimeError,
503509
match="all_frames, is_key_frame, and duration from custom_frame_mappings were not same size.",
504510
):
511+
decoder = create_from_file(
512+
str(NASA_VIDEO.path), seek_mode="custom_frame_mappings"
513+
)
514+
different_lengths = (
515+
torch.tensor([1, 2, 3]),
516+
torch.tensor([False, False]),
517+
torch.tensor([1, 2, 3]),
518+
)
505519
add_video_stream(
506520
decoder, stream_index=0, custom_frame_mappings=different_lengths
507521
)

0 commit comments

Comments
 (0)