Commit a2da767

More testing, docstring editing
1 parent 0d2492e commit a2da767

File tree

2 files changed: +11 −2 lines changed

src/torchcodec/decoders/_video_decoder.py

Lines changed: 1 addition & 1 deletion

@@ -69,7 +69,7 @@ class VideoDecoder:
             :ref:`sphx_glr_generated_examples_decoding_approximate_mode.py`
         transforms (sequence of transform objects, optional): Sequence of transforms to be
             applied to the decoded frames by the decoder itself, in order. Accepts both
-            torchcodec.transforms.DecoderTransform and torchvision.transforms.v2.Transform
+            ``torchcodec.transforms.DecoderTransform`` and ``torchvision.transforms.v2.Transform``
             objects. All transforms are applied in the ouput pixel format and colorspace.
         custom_frame_mappings (str, bytes, or file-like object, optional):
             Mapping of frames to their metadata, typically generated via ffprobe.
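For context, a minimal usage sketch of the transforms parameter documented above, pieced together from the test changes in this commit. The file name, target size, and frame index are placeholders; torchcodec.transforms.Resize and v2.Resize are the only transforms assumed here, both taken from the test diff below.

import torchcodec
from torchcodec.decoders import VideoDecoder
from torchvision.transforms import v2

# Hypothetical input file and target size; both transform flavors accepted
# by the docstring go through the same `transforms` argument.
source = "video.mp4"
decoder_tc = VideoDecoder(
    source, transforms=[torchcodec.transforms.Resize(size=(270, 480))]
)
decoder_tv = VideoDecoder(
    source, transforms=[v2.Resize(size=(270, 480))]
)

# Frames decoded through either decoder should come out at the requested
# size, as (C, H, W) tensors in the output pixel format.
frame_tc = decoder_tc[0]
frame_tv = decoder_tv[0]
assert frame_tc.shape == frame_tv.shape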

test/test_transform_ops.py

Lines changed: 10 additions & 1 deletion

@@ -13,6 +13,7 @@
 import pytest

 import torch
+import torchcodec

 from torchcodec._core import (
     _add_video_stream,
@@ -48,7 +49,12 @@ def test_resize_torchvision(
     height = int(video.get_height() * height_scaling_factor)
     width = int(video.get_width() * width_scaling_factor)

+    # We're using both the TorchCodec object and the TorchVision object to
+    # ensure that they specify exactly the same thing.
     decoder_resize = VideoDecoder(
+        video.path, transforms=[torchcodec.transforms.Resize(size=(height, width))]
+    )
+    decoder_resize_tv = VideoDecoder(
         video.path, transforms=[v2.Resize(size=(height, width))]
     )

@@ -68,15 +74,18 @@ def test_resize_torchvision(
         int(num_frames * 0.90),
         num_frames - 1,
     ]:
-        expected_shape = (video.get_num_color_channels(), height, width)
+        frame_resize_tv = decoder_resize_tv[frame_index]
         frame_resize = decoder_resize[frame_index]
+        assert_frames_equal(frame_resize_tv, frame_resize)
+
         frame_full = decoder_full[frame_index]

         frame_tv = v2.functional.resize(frame_full, size=(height, width))
         frame_tv_no_antialias = v2.functional.resize(
             frame_full, size=(height, width), antialias=False
         )

+        expected_shape = (video.get_num_color_channels(), height, width)
         assert frame_resize.shape == expected_shape
         assert frame_tv.shape == expected_shape
         assert frame_tv_no_antialias.shape == expected_shape
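As a usage note on the reference comparison this test performs, the torchvision side can be exercised on its own with a stand-in tensor. This is only a sketch: the frame shape and target size are made up, and the only torchvision calls assumed are v2.functional.resize and its antialias flag, both of which appear in the diff above.

import torch
from torchvision.transforms import v2

# Stand-in for a decoded full-resolution frame, (C, H, W) uint8.
frame_full = torch.zeros(3, 540, 960, dtype=torch.uint8)
height, width = 270, 480

# The same reference resizes the test computes, with and without antialiasing.
frame_tv = v2.functional.resize(frame_full, size=(height, width))
frame_tv_no_antialias = v2.functional.resize(
    frame_full, size=(height, width), antialias=False
)

expected_shape = (frame_full.shape[0], height, width)
assert frame_tv.shape == expected_shape
assert frame_tv_no_antialias.shape == expected_shape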
