@@ -9,7 +9,7 @@
 import numbers
 from dataclasses import dataclass
 from pathlib import Path
-from typing import List, Literal, Optional, Sequence, Tuple, Union
+from typing import Literal, Optional, Sequence, Tuple, Union

 import torch
 from torch import device as torch_device, nn, Tensor
@@ -20,7 +20,8 @@
     create_decoder,
     ERROR_REPORTING_INSTRUCTIONS,
 )
-from torchcodec.transforms import DecoderTransform, Resize
+from torchcodec.transforms import DecoderTransform
+from torchcodec.transforms._decoder_transforms import _make_transform_specs


 @dataclass
@@ -217,7 +218,10 @@ def __init__(
         device = str(device)

         device_variant = _get_cuda_backend()
-        transform_specs = _make_transform_specs(transforms)
+        transform_specs = _make_transform_specs(
+            transforms,
+            input_dims=(self.metadata.height, self.metadata.width),
+        )

         core.add_video_stream(
             self._decoder,
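
Note: this diff only shows the new call site, so the relocated helper's signature is inferred. Below is a minimal sketch of what `_make_transform_specs` plausibly looks like after the move, assuming `input_dims` carries the stream's native `(height, width)` and the spec format stays as in the removed code further down; the `input_dims` handling is hypothetical, not confirmed by this diff.

# Hedged sketch of the moved helper (now in
# torchcodec/transforms/_decoder_transforms.py). The ""/semicolon behavior is
# copied from the removed _make_transform_specs below; how input_dims is used
# is an assumption based only on the new call site.
from typing import Optional, Sequence, Tuple

def _make_transform_specs_sketch(
    transforms: Optional[Sequence["DecoderTransform"]],
    input_dims: Tuple[Optional[int], Optional[int]],
) -> str:
    if transforms is None:
        return ""
    # input_dims = (height, width) from stream metadata; presumably it lets a
    # transform that omits a dimension (e.g. a single-int resize) resolve it
    # against the source frame size before rendering its spec.
    return ";".join(t._make_transform_spec() for t in transforms)
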
@@ -523,78 +527,6 @@ def _get_and_validate_stream_metadata(
     )


-def _convert_to_decoder_transforms(
-    transforms: Sequence[Union[DecoderTransform, nn.Module]],
-) -> List[DecoderTransform]:
-    """Convert a sequence of transforms that may contain TorchVision transform
-    objects into a list of only TorchCodec transform objects.
-
-    Args:
-        transforms: Squence of transform objects. The objects can be one of two
-            types:
-            1. torchcodec.transforms.DecoderTransform
-            2. torchvision.transforms.v2.Transform, but our type annotation
-               only mentions its base, nn.Module. We don't want to take a
-               hard dependency on TorchVision.
-
-    Returns:
-        List of DecoderTransform objects.
-    """
-    try:
-        from torchvision.transforms import v2
-
-        tv_available = True
-    except ImportError:
-        tv_available = False
-
-    converted_transforms: list[DecoderTransform] = []
-    for transform in transforms:
-        if not isinstance(transform, DecoderTransform):
-            if not tv_available:
-                raise ValueError(
-                    f"The supplied transform, {transform}, is not a TorchCodec "
-                    " DecoderTransform. TorchCodec also accept TorchVision "
-                    "v2 transforms, but TorchVision is not installed."
-                )
-            elif isinstance(transform, v2.Resize):
-                converted_transforms.append(Resize._from_torchvision(transform))
-            else:
-                raise ValueError(
-                    f"Unsupported transform: {transform}. Transforms must be "
-                    "either a TorchCodec DecoderTransform or a TorchVision "
-                    "v2 transform."
-                )
-        else:
-            converted_transforms.append(transform)
-
-    return converted_transforms
-
-
-def _make_transform_specs(
-    transforms: Optional[Sequence[Union[DecoderTransform, nn.Module]]],
-) -> str:
-    """Given a sequence of transforms, turn those into the specification string
-    the core API expects.
-
-    Args:
-        transforms: Optional sequence of transform objects. The objects can be
-            one of two types:
-            1. torchcodec.transforms.DecoderTransform
-            2. torchvision.transforms.v2.Transform, but our type annotation
-               only mentions its base, nn.Module. We don't want to take a
-               hard dependency on TorchVision.
-
-    Returns:
-        String of transforms in the format the core API expects: transform
-        specifications separate by semicolons.
-    """
-    if transforms is None:
-        return ""
-
-    transforms = _convert_to_decoder_transforms(transforms)
-    return ";".join([t._make_transform_spec() for t in transforms])
-
-
 def _read_custom_frame_mappings(
     custom_frame_mappings: Union[str, bytes, io.RawIOBase, io.BufferedReader]
 ) -> tuple[Tensor, Tensor, Tensor]:
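
For readers following the public API: the code path touched here is the `transforms` argument of `VideoDecoder`. A hedged usage sketch follows, with an illustrative file path and target size; `Resize` taking a `(height, width)` size and the TorchVision v2 fallback follow the removed conversion code above.

from torchcodec.decoders import VideoDecoder
from torchcodec.transforms import Resize

# Decoder-side resize: the spec string built by _make_transform_specs tells
# the core decoder to scale frames during decoding, so no separate
# post-processing pass is needed. "video.mp4" is a placeholder path.
decoder = VideoDecoder("video.mp4", transforms=[Resize(size=(270, 480))])
first_frame = decoder[0]  # a C x 270 x 480 uint8 tensor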