diff --git a/docs/source/api_ref_transforms.rst b/docs/source/api_ref_transforms.rst index 04ef28ab9..e67cbddc6 100644 --- a/docs/source/api_ref_transforms.rst +++ b/docs/source/api_ref_transforms.rst @@ -14,4 +14,5 @@ For a tutorial, see: TODO_DECODER_TRANSFORMS_TUTORIAL. :template: dataclass.rst DecoderTransform + RandomCrop Resize diff --git a/src/torchcodec/transforms/_decoder_transforms.py b/src/torchcodec/transforms/_decoder_transforms.py index ed38820b1..f3f0c45a3 100644 --- a/src/torchcodec/transforms/_decoder_transforms.py +++ b/src/torchcodec/transforms/_decoder_transforms.py @@ -5,7 +5,6 @@ # LICENSE file in the root directory of this source tree. from abc import ABC, abstractmethod -from dataclasses import dataclass from types import ModuleType from typing import Optional, Sequence, Tuple @@ -13,7 +12,6 @@ from torch import nn -@dataclass class DecoderTransform(ABC): """Base class for all decoder transforms. @@ -91,7 +89,6 @@ def import_torchvision_transforms_v2() -> ModuleType: return v2 -@dataclass class Resize(DecoderTransform): """Resize the decoded frame to a given size. @@ -103,18 +100,20 @@ class Resize(DecoderTransform): the form (height, width). """ - size: Sequence[int] + def __init__(self, size: Sequence[int]): + if len(size) != 2: + raise ValueError( + "Resize transform must have a (height, width) " + f"pair for the size, got {size}." + ) + self.size = size def _make_transform_spec( self, input_dims: Tuple[Optional[int], Optional[int]] ) -> str: - # TODO: establish this invariant in the constructor during refactor - assert len(self.size) == 2 return f"resize, {self.size[0]}, {self.size[1]}" def _get_output_dims(self) -> Optional[Tuple[Optional[int], Optional[int]]]: - # TODO: establish this invariant in the constructor during refactor - assert len(self.size) == 2 return (self.size[0], self.size[1]) @classmethod @@ -141,7 +140,6 @@ def _from_torchvision(cls, tv_resize: nn.Module): return cls(size=tv_resize.size) -@dataclass class RandomCrop(DecoderTransform): """Crop the decoded frame to a given size at a random location in the frame. @@ -158,17 +156,17 @@ class RandomCrop(DecoderTransform): the form (height, width). """ - size: Sequence[int] + def __init__(self, size: Sequence[int]): + if len(size) != 2: + raise ValueError( + "RandomCrop transform must have a (height, width) " + f"pair for the size, got {size}." + ) + self.size = size def _make_transform_spec( self, input_dims: Tuple[Optional[int], Optional[int]] ) -> str: - if len(self.size) != 2: - raise ValueError( - f"RandomCrop's size must be a sequence of length 2, got {self.size}. " - "This should never happen, please report a bug." - ) - height, width = input_dims if height is None: raise ValueError( @@ -196,8 +194,6 @@ def _make_transform_spec( return f"crop, {self.size[0]}, {self.size[1]}, {left}, {top}" def _get_output_dims(self) -> Optional[Tuple[Optional[int], Optional[int]]]: - # TODO: establish this invariant in the constructor during refactor - assert len(self.size) == 2 return (self.size[0], self.size[1]) @classmethod diff --git a/test/test_transform_ops.py b/test/test_transform_ops.py index 5839f79a4..8aa151b01 100644 --- a/test/test_transform_ops.py +++ b/test/test_transform_ops.py @@ -145,6 +145,15 @@ def test_resize_fails(self): ): VideoDecoder(NASA_VIDEO.path, transforms=[v2.Resize(size=(100))]) + with pytest.raises( + ValueError, + match=r"must have a \(height, width\) pair for the size", + ): + VideoDecoder( + NASA_VIDEO.path, + transforms=[torchcodec.transforms.Resize(size=(100, 100, 100))], + ) + @pytest.mark.parametrize( "height_scaling_factor, width_scaling_factor", ((0.5, 0.5), (0.25, 0.1), (1.0, 1.0), (0.15, 0.75)),