Implement RandomCrop transform #1070
```diff
@@ -19,7 +19,7 @@
     create_decoder,
     ERROR_REPORTING_INSTRUCTIONS,
 )
-from torchcodec.transforms import DecoderTransform, Resize
+from torchcodec.transforms import DecoderTransform, RandomCrop, Resize


 class VideoDecoder:
```
```diff
@@ -167,7 +167,11 @@ def __init__(
         device = str(device)

         device_variant = _get_cuda_backend()
-        transform_specs = _make_transform_specs(transforms)
+        transform_specs = _make_transform_specs(
+            transforms,
+            input_dims=(self.metadata.height, self.metadata.width),
+            dimension_order=dimension_order,
+        )

         core.add_video_stream(
             self._decoder,
```
```diff
@@ -450,6 +454,8 @@ def _get_and_validate_stream_metadata(

 def _convert_to_decoder_transforms(
     transforms: Sequence[Union[DecoderTransform, nn.Module]],
+    input_dims: Tuple[Optional[int], Optional[int]],
+    dimension_order: Literal["NCHW", "NHWC"],
 ) -> List[DecoderTransform]:
     """Convert a sequence of transforms that may contain TorchVision transform
     objects into a list of only TorchCodec transform objects.
```
```diff
@@ -482,21 +488,39 @@ def _convert_to_decoder_transforms(
                     "v2 transforms, but TorchVision is not installed."
                 )
             elif isinstance(transform, v2.Resize):
-                converted_transforms.append(Resize._from_torchvision(transform))
+                transform_tc = Resize._from_torchvision(transform)
+                input_dims = transform_tc._get_output_dims(input_dims)
+                converted_transforms.append(transform_tc)
+            elif isinstance(transform, v2.RandomCrop):
+                if dimension_order != "NCHW":
+                    raise ValueError(
+                        "TorchVision v2 RandomCrop is only supported for NCHW "
+                        "dimension order. Please use the TorchCodec RandomCrop "
+                        "transform instead."
+                    )
+                transform_tc = RandomCrop._from_torchvision(
+                    transform,
+                    input_dims,
+                )
+                input_dims = transform_tc._get_output_dims(input_dims)
+                converted_transforms.append(transform_tc)
             else:
                 raise ValueError(
                     f"Unsupported transform: {transform}. Transforms must be "
                     "either a TorchCodec DecoderTransform or a TorchVision "
                     "v2 transform."
                 )
         else:
+            input_dims = transform._get_output_dims(input_dims)
```
Contributor:

Is `input_dims` here used for validation?

Contributor (Author):

It's not actually used for validation. Think of the transforms as a pipeline: the dimensions that one transform produces as its output become the input dimensions of the next transform in the sequence. This probably deserves a comment. :)
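To make the pipeline picture concrete, here is a minimal sketch (not part of the diff; the frame sizes are made up) of how `_convert_to_decoder_transforms` threads dimensions through the transforms:

```python
from torchcodec.transforms import RandomCrop, Resize

# Hypothetical source whose metadata reports 1080x1920 frames.
dims = (1080, 1920)

# Each transform consumes the dims produced by the previous one, mirroring
# the loop in _convert_to_decoder_transforms above.
for t in [Resize(size=(720, 1280)), RandomCrop(size=(224, 224))]:
    dims = t._get_output_dims(dims)

# After Resize: (720, 1280). After RandomCrop: (224, 224).
```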
```diff
             converted_transforms.append(transform)

     return converted_transforms


 def _make_transform_specs(
     transforms: Optional[Sequence[Union[DecoderTransform, nn.Module]]],
+    input_dims: Tuple[Optional[int], Optional[int]],
+    dimension_order: Literal["NCHW", "NHWC"],
 ) -> str:
     """Given a sequence of transforms, turn those into the specification string
     the core API expects.
```

```diff
@@ -516,7 +540,7 @@ def _make_transform_specs(
     if transforms is None:
         return ""

-    transforms = _convert_to_decoder_transforms(transforms)
+    transforms = _convert_to_decoder_transforms(transforms, input_dims, dimension_order)
     return ";".join([t._make_transform_spec() for t in transforms])
```
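For reference, `_make_transform_specs` just joins each transform's `_make_transform_spec()` output with semicolons. A hypothetical result for the pipeline sketched earlier (the crop offsets are made up, since they are randomly sampled):

```python
spec = ";".join([t._make_transform_spec() for t in transforms])
# e.g. "resize, 720, 1280;crop, 224, 224, 417, 103"
#      Resize(size=(720, 1280)) followed by RandomCrop(size=(224, 224))
#      with sampled left=417, top=103.
```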
```diff
@@ -7,8 +7,9 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from types import ModuleType
-from typing import Sequence
+from typing import Optional, Sequence, Tuple

+import torch
 from torch import nn
```
```diff
@@ -22,8 +23,8 @@ class DecoderTransform(ABC):
     decoded frames and applying the same kind of transform.

     Most ``DecoderTransform`` objects have a complementary transform in TorchVision,
-    specifically in `torchvision.transforms.v2 <https://docs.pytorch.org/vision/stable/transforms.html>`_. For such transforms, we
-    ensure that:
+    specifically in `torchvision.transforms.v2 <https://docs.pytorch.org/vision/stable/transforms.html>`_.
+    For such transforms, we ensure that:

     1. The names are the same.
     2. Default behaviors are the same.
```
```diff
@@ -40,6 +41,11 @@
     def _make_transform_spec(self) -> str:
         pass

+    def _get_output_dims(
+        self, input_dims: Tuple[Optional[int], Optional[int]]
+    ) -> Tuple[Optional[int], Optional[int]]:
+        return input_dims
+

 def import_torchvision_transforms_v2() -> ModuleType:
     try:
```
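The base-class default above means dimension-preserving transforms need no override. A hypothetical subclass (not part of this PR, and `"grayscale"` is an invented spec string) would simply inherit it:

```python
from dataclasses import dataclass

@dataclass
class Grayscale(DecoderTransform):
    # Inherits DecoderTransform._get_output_dims, which returns
    # input_dims unchanged: a grayscale conversion keeps frame size.
    def _make_transform_spec(self) -> str:
        return "grayscale"
```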
```diff
@@ -66,28 +72,167 @@ class Resize(DecoderTransform):
     size: Sequence[int]

     def _make_transform_spec(self) -> str:
+        # TODO: establish this invariant in the constructor during refactor
         assert len(self.size) == 2
         return f"resize, {self.size[0]}, {self.size[1]}"

+    def _get_output_dims(
+        self, input_dims: Tuple[Optional[int], Optional[int]]
+    ) -> Tuple[Optional[int], Optional[int]]:
+        # TODO: establish this invariant in the constructor during refactor
+        assert len(self.size) == 2
+        return (self.size[0], self.size[1])
+
     @classmethod
-    def _from_torchvision(cls, resize_tv: nn.Module):
+    def _from_torchvision(cls, tv_resize: nn.Module):
         v2 = import_torchvision_transforms_v2()

-        assert isinstance(resize_tv, v2.Resize)
+        assert isinstance(tv_resize, v2.Resize)

-        if resize_tv.interpolation is not v2.InterpolationMode.BILINEAR:
+        if tv_resize.interpolation is not v2.InterpolationMode.BILINEAR:
             raise ValueError(
                 "TorchVision Resize transform must use bilinear interpolation."
             )
-        if resize_tv.antialias is False:
+        if tv_resize.antialias is False:
             raise ValueError(
                 "TorchVision Resize transform must have antialias enabled."
             )
-        if resize_tv.size is None:
+        if tv_resize.size is None:
             raise ValueError("TorchVision Resize transform must have a size specified.")
-        if len(resize_tv.size) != 2:
+        if len(tv_resize.size) != 2:
             raise ValueError(
                 "TorchVision Resize transform must have a (height, width) "
-                f"pair for the size, got {resize_tv.size}."
+                f"pair for the size, got {tv_resize.size}."
             )
-        return cls(size=resize_tv.size)
+        return cls(size=tv_resize.size)
+
+
+@dataclass
+class RandomCrop(DecoderTransform):
+    """Crop the decoded frame to a given size at a random location in the frame.
+
+    Complementary TorchVision transform: :class:`~torchvision.transforms.v2.RandomCrop`.
+    Padding of all kinds is disabled. The random location within the frame is
+    determined during the initialization of the
+    :class:`~torchcodec.decoders.VideoDecoder` object that owns this transform.
+    As a consequence, each decoded frame in the video will be cropped at the
+    same location. Videos with variable resolution may result in undefined
+    behavior.
+
+    Args:
+        size (sequence of int): Desired output size. Must be a sequence of
+            the form (height, width).
+    """
+
+    size: Sequence[int]
+    _top: Optional[int] = None
+    _left: Optional[int] = None
+    _input_dims: Optional[Tuple[int, int]] = None
+
+    def _make_transform_spec(self) -> str:
+        if len(self.size) != 2:
+            raise ValueError(
+                f"RandomCrop's size must be a sequence of length 2, got {self.size}. "
+                "This should never happen, please report a bug."
+            )
+
+        if self._top is None or self._left is None:
+            # TODO: It would be very strange if only ONE of those is None.
+            #       But should we make it an error? We can continue, but it
+            #       would probably mean something bad happened. Dear reviewer,
+            #       please register an opinion here:
```
Contributor:

I agree it would appear something bad happened in this case. But when calling this function, do we expect …

Contributor (Author):

It depends on if … It has occurred to me that maybe we don't need to call …
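For clarity, a sketch of the two call paths under discussion, assuming the classes as defined in this diff:

```python
from torchcodec.transforms import RandomCrop

# Path 1: TorchCodec-native. _top/_left start as None, so the sampling
# branch below runs; it relies on _input_dims, recorded earlier when the
# decoder called _get_output_dims().
rc = RandomCrop(size=(224, 224))
rc._get_output_dims((720, 1280))  # records _input_dims = (720, 1280)
spec = rc._make_transform_spec()  # samples _top/_left lazily here

# Path 2: TorchVision interop. _from_torchvision() pre-computes _top/_left
# via make_params(), so _make_transform_spec() never samples.
```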
```diff
+            if self._input_dims is None:
+                raise ValueError(
+                    "RandomCrop's input_dims must be set before calling _make_transform_spec(). "
+                    "This should never happen, please report a bug."
+                )
+            if self._input_dims[0] < self.size[0] or self._input_dims[1] < self.size[1]:
+                raise ValueError(
+                    f"Input dimensions {self._input_dims} are smaller than the crop size {self.size}."
+                )
+
+            # Note: This logic must match the logic in
+            # torchvision.transforms.v2.RandomCrop.make_params(). Given
+            # the same seed, they should get the same result. This is an
+            # API guarantee with our users.
+            self._top = int(
+                torch.randint(0, self._input_dims[0] - self.size[0] + 1, size=()).item()
+            )
+            self._left = int(
+                torch.randint(0, self._input_dims[1] - self.size[1] + 1, size=()).item()
+            )
+
+        return f"crop, {self.size[0]}, {self.size[1]}, {self._left}, {self._top}"
+
+    def _get_output_dims(
+        self, input_dims: Tuple[Optional[int], Optional[int]]
+    ) -> Tuple[Optional[int], Optional[int]]:
+        # TODO: establish this invariant in the constructor during refactor
+        assert len(self.size) == 2
+
+        height, width = input_dims
+        if height is None:
+            raise ValueError(
+                "Video metadata has no height. RandomCrop can only be used "
+                "when input frame dimensions are known."
+            )
+        if width is None:
+            raise ValueError(
+                "Video metadata has no width. RandomCrop can only be used "
+                "when input frame dimensions are known."
+            )
+
+        self._input_dims = (height, width)
+        return (self.size[0], self.size[1])
+
+    @classmethod
+    def _from_torchvision(
+        cls,
+        tv_random_crop: nn.Module,
+        input_dims: Tuple[Optional[int], Optional[int]],
+    ):
+        v2 = import_torchvision_transforms_v2()
+
+        assert isinstance(tv_random_crop, v2.RandomCrop)
+
+        if tv_random_crop.padding is not None:
+            raise ValueError(
+                "TorchVision RandomCrop transform must not specify padding."
+            )
+
+        if tv_random_crop.pad_if_needed is True:
+            raise ValueError(
+                "TorchVision RandomCrop transform must not specify pad_if_needed."
+            )
+
+        if tv_random_crop.fill != 0:
+            raise ValueError("TorchVision RandomCrop fill must be 0.")
+
+        if tv_random_crop.padding_mode != "constant":
+            raise ValueError("TorchVision RandomCrop padding_mode must be constant.")
+
+        if len(tv_random_crop.size) != 2:
+            raise ValueError(
+                "TorchVision RandomCrop transform must have a (height, width) "
+                f"pair for the size, got {tv_random_crop.size}."
+            )
+
+        height, width = input_dims
+        if height is None:
+            raise ValueError(
+                "Video metadata has no height. RandomCrop can only be used "
+                "when input frame dimensions are known."
+            )
+        if width is None:
+            raise ValueError(
+                "Video metadata has no width. RandomCrop can only be used "
+                "when input frame dimensions are known."
+            )
+
+        # Note that TorchVision v2 transforms only accept NCHW tensors.
+        params = tv_random_crop.make_params(
+            torch.empty(size=(3, height, width), dtype=torch.uint8)
+        )
+
+        if tv_random_crop.size != (params["height"], params["width"]):
+            raise ValueError(
+                f"TorchVision RandomCrop's provided size, {tv_random_crop.size} "
+                f"must match the computed size, {params['height'], params['width']}."
+            )
+
+        return cls(size=tv_random_crop.size, _top=params["top"], _left=params["left"])
```
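Putting the two files together, a usage sketch of what this PR enables (the file name is hypothetical):

```python
from torchcodec.decoders import VideoDecoder
from torchcodec.transforms import RandomCrop
from torchvision.transforms import v2

# TorchCodec-native transform.
decoder = VideoDecoder("video.mp4", transforms=[RandomCrop(size=(224, 224))])

# TorchVision interop: per this diff it requires the default NCHW dimension
# order and a RandomCrop with no padding options.
decoder = VideoDecoder("video.mp4", transforms=[v2.RandomCrop(size=(224, 224))])
```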
Contributor:
The location (0, 0) is a valid image location. 🤦
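That remark points at a classic pitfall: a crop anchored at the top-left corner has `top == left == 0`, which is falsy. Only an identity check against `None`, as in the diff's `if self._top is None or self._left is None`, treats (0, 0) as the valid, already-sampled location it is:

```python
top, left = 0, 0        # valid crop location at the top-left corner
assert top is not None  # correct: the offsets are set
assert not top          # truthiness would wrongly say "unset"
```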