diff --git a/src/torchcodec/decoders/_video_decoder.py b/src/torchcodec/decoders/_video_decoder.py
index 776034256..9c2727bad 100644
--- a/src/torchcodec/decoders/_video_decoder.py
+++ b/src/torchcodec/decoders/_video_decoder.py
@@ -19,7 +19,8 @@
     create_decoder,
     ERROR_REPORTING_INSTRUCTIONS,
 )
-from torchcodec.transforms import CenterCrop, DecoderTransform, RandomCrop, Resize
+from torchcodec.transforms import DecoderTransform
+from torchcodec.transforms._decoder_transforms import _make_transform_specs
 
 
 class VideoDecoder:
@@ -451,104 +452,6 @@ def _get_and_validate_stream_metadata(
     )
 
 
-def _make_transform_specs(
-    transforms: Optional[Sequence[Union[DecoderTransform, nn.Module]]],
-    input_dims: Tuple[Optional[int], Optional[int]],
-) -> str:
-    """Given a sequence of transforms, turn those into the specification string
-    the core API expects.
-
-    Args:
-        transforms: Optional sequence of transform objects. The objects can be
-            one of two types:
-                1. torchcodec.transforms.DecoderTransform
-                2. torchvision.transforms.v2.Transform, but our type annotation
-                   only mentions its base, nn.Module. We don't want to take a
-                   hard dependency on TorchVision.
-        input_dims: Optional (height, width) pair. Note that only some
-            transforms need to know the dimensions. If the user provides
-            transforms that don't need to know the dimensions, and that metadata
-            is missing, everything should still work. That means we assert their
-            existence as late as possible.
-
-    Returns:
-        String of transforms in the format the core API expects: transform
-        specifications separated by semicolons.
-    """
-    if transforms is None:
-        return ""
-
-    try:
-        from torchvision.transforms import v2
-
-        tv_available = True
-    except ImportError:
-        tv_available = False
-
-    # The following loop accomplishes two tasks:
-    #
-    #   1. Converts the transform to a DecoderTransform, if necessary. We
-    #      accept TorchVision transform objects and they must be converted
-    #      to their matching DecoderTransform.
-    #   2. Calculates what the input dimensions are to each transform.
-    #
-    # The order in our transforms list is semantically meaningful, as we
-    # actually have a pipeline where the output of one transform is the input to
-    # the next. For example, if we have the transforms list [A, B, C, D], then
-    # we should understand that as:
-    #
-    #   A -> B -> C -> D
-    #
-    # Where the frame produced by A is the input to B, the frame produced by B
-    # is the input to C, etc. This particularly matters for frame dimensions.
-    # Transforms can both:
-    #
-    #   1. Produce frames with arbitrary dimensions.
-    #   2. Rely on their input frame's dimensions to calculate ahead-of-time
-    #      what their runtime behavior will be.
-    #
-    # The consequence of the above facts is that we need to statically track
-    # frame dimensions in the pipeline while we pre-process it. The input
-    # frame's dimensions to A, our first transform, is always what we know from
-    # our metadata. For each transform, we always calculate its output
-    # dimensions from its input dimensions. We store these with the converted
-    # transform, to be all used together when we generate the specs.
-    converted_transforms: list[
-        Tuple[
-            DecoderTransform,
-            # A (height, width) pair where the values may be missing.
-            Tuple[Optional[int], Optional[int]],
-        ]
-    ] = []
-    curr_input_dims = input_dims
-    for transform in transforms:
-        if not isinstance(transform, DecoderTransform):
-            if not tv_available:
-                raise ValueError(
-                    f"The supplied transform, {transform}, is not a TorchCodec "
-                    "DecoderTransform. TorchCodec also accepts TorchVision "
-                    "v2 transforms, but TorchVision is not installed."
-                )
-            elif isinstance(transform, v2.Resize):
-                transform = Resize._from_torchvision(transform)
-            elif isinstance(transform, v2.CenterCrop):
-                transform = CenterCrop._from_torchvision(transform)
-            elif isinstance(transform, v2.RandomCrop):
-                transform = RandomCrop._from_torchvision(transform)
-            else:
-                raise ValueError(
-                    f"Unsupported transform: {transform}. Transforms must be "
-                    "either a TorchCodec DecoderTransform or a TorchVision "
-                    "v2 transform."
-                )
-
-        converted_transforms.append((transform, curr_input_dims))
-        output_dims = transform._get_output_dims()
-        curr_input_dims = output_dims if output_dims is not None else curr_input_dims
-
-    return ";".join([t._make_transform_spec(dims) for t, dims in converted_transforms])
-
-
 def _read_custom_frame_mappings(
     custom_frame_mappings: Union[str, bytes, io.RawIOBase, io.BufferedReader]
 ) -> tuple[Tensor, Tensor, Tensor]:
diff --git a/src/torchcodec/transforms/_decoder_transforms.py b/src/torchcodec/transforms/_decoder_transforms.py
index 7b6d3d1c7..6a99dd800 100644
--- a/src/torchcodec/transforms/_decoder_transforms.py
+++ b/src/torchcodec/transforms/_decoder_transforms.py
@@ -6,7 +6,7 @@
 
 from abc import ABC, abstractmethod
 from types import ModuleType
-from typing import Optional, Sequence, Tuple
+from typing import Optional, Sequence, Tuple, Union
 
 import torch
 from torch import nn
@@ -282,3 +282,101 @@ def _from_torchvision(
             )
 
         return cls(size=tv_random_crop.size)
+
+
+def _make_transform_specs(
+    transforms: Optional[Sequence[Union[DecoderTransform, nn.Module]]],
+    input_dims: Tuple[Optional[int], Optional[int]],
+) -> str:
+    """Given a sequence of transforms, turn those into the specification string
+    the core API expects.
+
+    Args:
+        transforms: Optional sequence of transform objects. The objects can be
+            one of two types:
+                1. torchcodec.transforms.DecoderTransform
+                2. torchvision.transforms.v2.Transform, but our type annotation
+                   only mentions its base, nn.Module. We don't want to take a
+                   hard dependency on TorchVision.
+        input_dims: Optional (height, width) pair. Note that only some
+            transforms need to know the dimensions. If the user provides
+            transforms that don't need to know the dimensions, and that metadata
+            is missing, everything should still work. That means we assert their
+            existence as late as possible.
+
+    Returns:
+        String of transforms in the format the core API expects: transform
+        specifications separated by semicolons.
+    """
+    if transforms is None:
+        return ""
+
+    try:
+        from torchvision.transforms import v2
+
+        tv_available = True
+    except ImportError:
+        tv_available = False
+
+    # The following loop accomplishes two tasks:
+    #
+    #   1. Converts the transform to a DecoderTransform, if necessary. We
+    #      accept TorchVision transform objects and they must be converted
+    #      to their matching DecoderTransform.
+    #   2. Calculates what the input dimensions are to each transform.
+    #
+    # The order in our transforms list is semantically meaningful, as we
+    # actually have a pipeline where the output of one transform is the input to
+    # the next. For example, if we have the transforms list [A, B, C, D], then
+    # we should understand that as:
+    #
+    #   A -> B -> C -> D
+    #
+    # Where the frame produced by A is the input to B, the frame produced by B
+    # is the input to C, etc. This particularly matters for frame dimensions.
+    # Transforms can both:
+    #
+    #   1. Produce frames with arbitrary dimensions.
+    #   2. Rely on their input frame's dimensions to calculate ahead-of-time
+    #      what their runtime behavior will be.
+    #
+    # The consequence of the above facts is that we need to statically track
+    # frame dimensions in the pipeline while we pre-process it. The input
+    # frame's dimensions to A, our first transform, is always what we know from
+    # our metadata. For each transform, we always calculate its output
+    # dimensions from its input dimensions. We store these with the converted
+    # transform, to be all used together when we generate the specs.
+    converted_transforms: list[
+        Tuple[
+            DecoderTransform,
+            # A (height, width) pair where the values may be missing.
+            Tuple[Optional[int], Optional[int]],
+        ]
+    ] = []
+    curr_input_dims = input_dims
+    for transform in transforms:
+        if not isinstance(transform, DecoderTransform):
+            if not tv_available:
+                raise ValueError(
+                    f"The supplied transform, {transform}, is not a TorchCodec "
+                    "DecoderTransform. TorchCodec also accepts TorchVision "
+                    "v2 transforms, but TorchVision is not installed."
+                )
+            elif isinstance(transform, v2.Resize):
+                transform = Resize._from_torchvision(transform)
+            elif isinstance(transform, v2.CenterCrop):
+                transform = CenterCrop._from_torchvision(transform)
+            elif isinstance(transform, v2.RandomCrop):
+                transform = RandomCrop._from_torchvision(transform)
+            else:
+                raise ValueError(
+                    f"Unsupported transform: {transform}. Transforms must be "
+                    "either a TorchCodec DecoderTransform or a TorchVision "
+                    "v2 transform."
+                )
+
+        converted_transforms.append((transform, curr_input_dims))
+        output_dims = transform._get_output_dims()
+        curr_input_dims = output_dims if output_dims is not None else curr_input_dims
+
+    return ";".join([t._make_transform_spec(dims) for t, dims in converted_transforms])
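Usage sketch (not part of the diff): after this move, _make_transform_specs is imported
into _video_decoder.py from torchcodec.transforms._decoder_transforms rather than being
defined there. A minimal illustration of how the helper is driven, assuming Resize and
CenterCrop accept a (height, width) size argument; the spec string described in the
comments is illustrative, not the core API's exact format.

    from torchcodec.transforms import CenterCrop, Resize
    from torchcodec.transforms._decoder_transforms import _make_transform_specs

    # Assumed constructor signatures: size=(height, width).
    transforms = [Resize(size=(270, 480)), CenterCrop(size=(224, 224))]

    # input_dims is the (height, width) pair from stream metadata; either value may be
    # None, and the helper only relies on it when a transform's spec depends on its
    # input dimensions.
    spec = _make_transform_specs(transforms, input_dims=(1080, 1920))
    # spec is a single string: one specification per transform, joined by ";",
    # which is what the core decoding API expects.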