|
8 | 8 | import json |
9 | 9 | import numbers |
10 | 10 | from pathlib import Path |
11 | | -from typing import List, Literal, Optional, Sequence, Tuple, Union |
| 11 | +from typing import Literal, Optional, Sequence, Tuple, Union |
12 | 12 |
|
13 | 13 | import torch |
14 | 14 | from torch import device as torch_device, nn, Tensor |
|
19 | 19 | create_decoder, |
20 | 20 | ERROR_REPORTING_INSTRUCTIONS, |
21 | 21 | ) |
22 | | -from torchcodec.transforms import DecoderTransform, Resize |
| 22 | +from torchcodec.transforms import DecoderTransform, RandomCrop, Resize |
23 | 23 |
|
24 | 24 |
|
25 | 25 | class VideoDecoder: |
@@ -167,7 +167,10 @@ def __init__( |
167 | 167 | device = str(device) |
168 | 168 |
|
169 | 169 | device_variant = _get_cuda_backend() |
170 | | - transform_specs = _make_transform_specs(transforms) |
| 170 | + transform_specs = _make_transform_specs( |
| 171 | + transforms, |
| 172 | + input_dims=(self.metadata.height, self.metadata.width), |
| 173 | + ) |
171 | 174 |
|
172 | 175 | core.add_video_stream( |
173 | 176 | self._decoder, |
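A minimal usage sketch of what the constructor change above enables: the stream's metadata dimensions now flow into the transform specs, so size-dependent decoder transforms such as RandomCrop can be resolved before decoding starts. The `transforms` keyword and the Resize/RandomCrop `size` arguments below are assumptions modeled on their TorchVision v2 counterparts, not verified signatures.

```python
# Hypothetical usage sketch, not part of this diff. Assumes Resize/RandomCrop
# accept a `size` argument mirroring torchvision.transforms.v2, and that
# "video.mp4" is a real file.
from torchcodec.decoders import VideoDecoder
from torchcodec.transforms import RandomCrop, Resize

decoder = VideoDecoder(
    "video.mp4",
    transforms=[Resize(size=(720, 1280)), RandomCrop(size=(224, 224))],
)
frame = decoder[0]  # frame is resized to 720x1280, then randomly cropped to 224x224
```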
@@ -448,76 +451,100 @@ def _get_and_validate_stream_metadata( |
448 | 451 | ) |
449 | 452 |
|
450 | 453 |
|
451 | | -def _convert_to_decoder_transforms( |
452 | | - transforms: Sequence[Union[DecoderTransform, nn.Module]], |
453 | | -) -> List[DecoderTransform]: |
454 | | - """Convert a sequence of transforms that may contain TorchVision transform |
455 | | - objects into a list of only TorchCodec transform objects. |
| 454 | +def _make_transform_specs( |
| 455 | + transforms: Optional[Sequence[Union[DecoderTransform, nn.Module]]], |
| 456 | + input_dims: Tuple[Optional[int], Optional[int]], |
| 457 | +) -> str: |
| 458 | + """Given a sequence of transforms, turn those into the specification string |
| 459 | + the core API expects. |
456 | 460 |
|
457 | 461 | Args: |
458 | | - transforms: Squence of transform objects. The objects can be one of two |
459 | | - types: |
| 462 | + transforms: Optional sequence of transform objects. The objects can be |
| 463 | + one of two types: |
460 | 464 | 1. torchcodec.transforms.DecoderTransform |
461 | 465 | 2. torchvision.transforms.v2.Transform, but our type annotation |
462 | 466 | only mentions its base, nn.Module. We don't want to take a |
463 | 467 | hard dependency on TorchVision. |
| 468 | + input_dims: (height, width) pair where either value may be None. Only |
| 469 | + some transforms need to know the input dimensions. If the user only |
| 470 | + provides transforms that don't need them, decoding should still work |
| 471 | + even when that metadata is missing, so we assert that the dimensions |
| 472 | + exist as late as possible. |
464 | 473 |
|
465 | 474 | Returns: |
466 | | - List of DecoderTransform objects. |
| 475 | + String of transforms in the format the core API expects: transform |
| 476 | + specifications separated by semicolons. |
467 | 477 | """ |
| 478 | + if transforms is None: |
| 479 | + return "" |
| 480 | + |
468 | 481 | try: |
469 | 482 | from torchvision.transforms import v2 |
470 | 483 |
|
471 | 484 | tv_available = True |
472 | 485 | except ImportError: |
473 | 486 | tv_available = False |
474 | 487 |
|
475 | | - converted_transforms: list[DecoderTransform] = [] |
| 488 | + # The following loop accomplishes two tasks: |
| 489 | + # |
| 490 | + # 1. Converts the transform to a DecoderTransform, if necessary. We |
| 491 | + # accept TorchVision transform objects and they must be converted |
| 492 | + # to their matching DecoderTransform. |
| 493 | + # 2. Calculates what the input dimensions are to each transform. |
| 494 | + # |
| 495 | + # The order in our transforms list is semantically meaningful, as we |
| 496 | + # actually have a pipeline where the output of one transform is the input to |
| 497 | + # the next. For example, if we have the transforms list [A, B, C, D], then |
| 498 | + # we should understand that as: |
| 499 | + # |
| 500 | + # A -> B -> C -> D |
| 501 | + # |
| 502 | + # Where the frame produced by A is the input to B, the frame produced by B |
| 503 | + # is the input to C, etc. This particularly matters for frame dimensions. |
| 504 | + # Transforms can both: |
| 505 | + # |
| 506 | + # 1. Produce frames with arbitrary dimensions. |
| 507 | + # 2. Rely on their input frame's dimensions to calculate ahead-of-time |
| 508 | + # what their runtime behavior will be. |
| 509 | + # |
| 510 | + # The consequence of the above facts is that we need to statically track |
| 511 | + # frame dimensions in the pipeline while we pre-process it. The input |
| 512 | + # frame's dimensions to A, our first transform, are always what we know from |
| 513 | + # our metadata. For each transform, we always calculate its output |
| 514 | + # dimensions from its input dimensions. We store these with the converted |
| 515 | + # transform, so they can all be used together when we generate the specs. |
| 516 | + converted_transforms: list[ |
| 517 | + Tuple[ |
| 518 | + DecoderTransform, |
| 519 | + # A (height, width) pair where the values may be missing. |
| 520 | + Tuple[Optional[int], Optional[int]], |
| 521 | + ] |
| 522 | + ] = [] |
| 523 | + curr_input_dims = input_dims |
476 | 524 | for transform in transforms: |
477 | 525 | if not isinstance(transform, DecoderTransform): |
478 | 526 | if not tv_available: |
479 | 527 | raise ValueError( |
480 | 528 | f"The supplied transform, {transform}, is not a TorchCodec " |
481 | | - " DecoderTransform. TorchCodec also accept TorchVision " |
| 529 | + "DecoderTransform. TorchCodec also accepts TorchVision " |
482 | 530 | "v2 transforms, but TorchVision is not installed." |
483 | 531 | ) |
484 | 532 | elif isinstance(transform, v2.Resize): |
485 | | - converted_transforms.append(Resize._from_torchvision(transform)) |
| 533 | + transform = Resize._from_torchvision(transform) |
| 534 | + elif isinstance(transform, v2.RandomCrop): |
| 535 | + transform = RandomCrop._from_torchvision(transform) |
486 | 536 | else: |
487 | 537 | raise ValueError( |
488 | 538 | f"Unsupported transform: {transform}. Transforms must be " |
489 | 539 | "either a TorchCodec DecoderTransform or a TorchVision " |
490 | 540 | "v2 transform." |
491 | 541 | ) |
492 | | - else: |
493 | | - converted_transforms.append(transform) |
494 | | - |
495 | | - return converted_transforms |
496 | 542 |
|
| 543 | + converted_transforms.append((transform, curr_input_dims)) |
| 544 | + output_dims = transform._get_output_dims() |
| 545 | + curr_input_dims = output_dims if output_dims is not None else curr_input_dims |
497 | 546 |
|
498 | | -def _make_transform_specs( |
499 | | - transforms: Optional[Sequence[Union[DecoderTransform, nn.Module]]], |
500 | | -) -> str: |
501 | | - """Given a sequence of transforms, turn those into the specification string |
502 | | - the core API expects. |
503 | | -
|
504 | | - Args: |
505 | | - transforms: Optional sequence of transform objects. The objects can be |
506 | | - one of two types: |
507 | | - 1. torchcodec.transforms.DecoderTransform |
508 | | - 2. torchvision.transforms.v2.Transform, but our type annotation |
509 | | - only mentions its base, nn.Module. We don't want to take a |
510 | | - hard dependency on TorchVision. |
511 | | -
|
512 | | - Returns: |
513 | | - String of transforms in the format the core API expects: transform |
514 | | - specifications separate by semicolons. |
515 | | - """ |
516 | | - if transforms is None: |
517 | | - return "" |
518 | | - |
519 | | - transforms = _convert_to_decoder_transforms(transforms) |
520 | | - return ";".join([t._make_transform_spec() for t in transforms]) |
| 547 | + return ";".join([t._make_transform_spec(dims) for t, dims in converted_transforms]) |
521 | 548 |
|
522 | 549 |
|
523 | 550 | def _read_custom_frame_mappings( |
|
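For context on the conversion loop above, here is a self-contained sketch of the dimension-tracking idea using stand-in classes rather than the real DecoderTransform subclasses: each transform is paired with the (height, width) it will receive, and a transform that changes the frame size propagates its output size to the next stage.

```python
# Illustration only: stub transforms stand in for DecoderTransform subclasses.
# The real loop additionally converts TorchVision transforms and later calls
# _make_transform_spec(dims) on each (transform, dims) pair.
from typing import Optional, Tuple


class StubResize:
    def __init__(self, size: Tuple[int, int]):
        self.size = size

    def _get_output_dims(self) -> Optional[Tuple[int, int]]:
        return self.size  # a resize fixes its own output dimensions


class StubRandomCrop:
    def __init__(self, size: Tuple[int, int]):
        self.size = size

    def _get_output_dims(self) -> Optional[Tuple[int, int]]:
        return self.size  # the crop output is the requested crop size


metadata_dims = (1080, 1920)  # (height, width) from stream metadata
pipeline = [StubResize((720, 1280)), StubRandomCrop((224, 224))]

curr_input_dims = metadata_dims
for t in pipeline:
    print(f"{type(t).__name__} receives input dims {curr_input_dims}")
    output_dims = t._get_output_dims()
    curr_input_dims = output_dims if output_dims is not None else curr_input_dims

# Output:
# StubResize receives input dims (1080, 1920)
# StubRandomCrop receives input dims (720, 1280)
```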