|
8 | 8 | import json |
9 | 9 | import numbers |
10 | 10 | from pathlib import Path |
11 | | -from typing import List, Literal, Optional, Sequence, Tuple, Union |
| 11 | +from typing import Literal, Optional, Sequence, Tuple, Union |
12 | 12 |
|
13 | 13 | import torch |
14 | 14 | from torch import device as torch_device, nn, Tensor |
@@ -170,7 +170,6 @@ def __init__( |
170 | 170 | transform_specs = _make_transform_specs( |
171 | 171 | transforms, |
172 | 172 | input_dims=(self.metadata.height, self.metadata.width), |
173 | | - dimension_order=dimension_order, |
174 | 173 | ) |
175 | 174 |
|
176 | 175 | core.add_video_stream( |
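The constructor hunk above now builds the transform spec string from the user-supplied `transforms` and the stream metadata's height and width, with no dependency on `dimension_order`. A minimal usage sketch of what this enables, assuming `VideoDecoder` exposes the `transforms` parameter shown in this hunk and that TorchVision is installed (the file path is a placeholder):

```python
from torchvision.transforms import v2

from torchcodec.decoders import VideoDecoder

# TorchVision v2 transforms are converted to their TorchCodec equivalents
# internally; alternatively, pass torchcodec.transforms objects directly.
decoder = VideoDecoder(
    "video.mp4",  # placeholder path
    transforms=[v2.Resize(size=(270, 480))],
)
```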
@@ -452,96 +451,97 @@ def _get_and_validate_stream_metadata( |
452 | 451 | ) |
453 | 452 |
|
454 | 453 |
|
455 | | -def _convert_to_decoder_transforms( |
456 | | - transforms: Sequence[Union[DecoderTransform, nn.Module]], |
| 454 | +def _make_transform_specs( |
| 455 | + transforms: Optional[Sequence[Union[DecoderTransform, nn.Module]]], |
457 | 456 | input_dims: Tuple[Optional[int], Optional[int]], |
458 | | - dimension_order: Literal["NCHW", "NHWC"], |
459 | | -) -> List[DecoderTransform]: |
460 | | - """Convert a sequence of transforms that may contain TorchVision transform |
461 | | - objects into a list of only TorchCodec transform objects. |
| 457 | +) -> str: |
| 458 | + """Given a sequence of transforms, turn those into the specification string |
| 459 | + the core API expects. |
462 | 460 |
|
463 | 461 | Args: |
464 | | - transforms: Squence of transform objects. The objects can be one of two |
465 | | - types: |
| 462 | + transforms: Optional sequence of transform objects. The objects can be |
| 463 | + one of two types: |
466 | 464 | 1. torchcodec.transforms.DecoderTransform |
467 | 465 | 2. torchvision.transforms.v2.Transform, but our type annotation |
468 | 466 | only mentions its base, nn.Module. We don't want to take a |
469 | 467 | hard dependency on TorchVision. |
| 462 | + input_dims: (height, width) pair; either value may be None. Only some |
| 463 | + transforms need to know the dimensions. If the user provides |
| 470 | + transforms that don't need to know the dimensions, and that metadata |
| 471 | + is missing, everything should still work. That means we assert that |
| 472 | + the dimensions exist as late as possible. |
470 | 473 |
|
471 | 474 | Returns: |
472 | | - List of DecoderTransform objects. |
| 475 | + String of transforms in the format the core API expects: transform |
| 476 | + specifications separated by semicolons. |
473 | 477 | """ |
| 478 | + if transforms is None: |
| 479 | + return "" |
| 480 | + |
474 | 481 | try: |
475 | 482 | from torchvision.transforms import v2 |
476 | 483 |
|
477 | 484 | tv_available = True |
478 | 485 | except ImportError: |
479 | 486 | tv_available = False |
480 | 487 |
|
481 | | - converted_transforms: list[DecoderTransform] = [] |
| 488 | + # The following loop accomplishes two tasks: |
| 489 | + # |
| 490 | + # 1. Converts the transform to a DecoderTransform, if necessary. We |
| 491 | + # accept TorchVision transform objects and they must be converted |
| 492 | + # to their matching DecoderTransform. |
| 493 | + # 2. Calculates what the input dimensions are to each transform. |
| 494 | + # |
| 495 | + # The order in our transforms list is semantically meaningful, as we |
| 496 | + # actually have a pipeline where the output of one transform is the input to |
| 497 | + # the next. For example, if we have the transforms list [A, B, C, D], then |
| 498 | + # we should understand that as: |
| 499 | + # A -> B -> C -> D |
| 500 | + # Where the frame produced by A is the input to B, the frame produced by B |
| 501 | + # is the input to C, etc. This particularly matters for frame dimensions. |
| 502 | + # Transforms can both: |
| 503 | + # |
| 504 | + # 1. Produce frames with arbitrary dimensions. |
| 505 | + # 2. Rely on their input frame's dimensions to calculate ahead-of-time |
| 506 | + # what their runtime behavior will be. |
| 507 | + # |
| 508 | + # The consequence of the above facts is that we need to statically track |
| 509 | + # frame dimensions in the pipeline while we pre-process it. The input |
| 510 | + # frame's dimensions to A, our first transform, are always what we know from |
| 511 | + # our metadata. For each transform, we always calculate its output |
| 512 | + # dimensions from its input dimensions. We store these with the converted |
| 513 | + # transform, so they can be used together when we generate the specs. |
| 514 | + converted_transforms: list[Tuple[DecoderTransform, Tuple[Optional[int], Optional[int]]]] = [] |
| 515 | + curr_input_dims = input_dims |
482 | 516 | for transform in transforms: |
483 | | - if not isinstance(transform, DecoderTransform): |
| 517 | + if isinstance(transform, DecoderTransform): |
| 518 | + output_dims = transform._calculate_output_dims(curr_input_dims) |
| 519 | + converted_transforms.append((transform, curr_input_dims)) |
| 520 | + else: |
484 | 521 | if not tv_available: |
485 | 522 | raise ValueError( |
486 | 523 | f"The supplied transform, {transform}, is not a TorchCodec " |
487 | | - " DecoderTransform. TorchCodec also accept TorchVision " |
| 524 | + " DecoderTransform. TorchCodec also accepts TorchVision " |
488 | 525 | "v2 transforms, but TorchVision is not installed." |
489 | 526 | ) |
490 | 527 | elif isinstance(transform, v2.Resize): |
491 | | - transform_tc = Resize._from_torchvision(transform) |
492 | | - input_dims = transform_tc._get_output_dims(input_dims) |
493 | | - converted_transforms.append(transform_tc) |
| 528 | + tc_transform = Resize._from_torchvision(transform) |
| 529 | + output_dims = tc_transform._calculate_output_dims(curr_input_dims) |
| 530 | + converted_transforms.append((tc_transform, curr_input_dims)) |
494 | 531 | elif isinstance(transform, v2.RandomCrop): |
495 | | - if dimension_order != "NCHW": |
496 | | - raise ValueError( |
497 | | - "TorchVision v2 RandomCrop is only supported for NCHW " |
498 | | - "dimension order. Please use the TorchCodec RandomCrop " |
499 | | - "transform instead." |
500 | | - ) |
501 | | - transform_tc = RandomCrop._from_torchvision( |
502 | | - transform, |
503 | | - input_dims, |
504 | | - ) |
505 | | - input_dims = transform_tc._get_output_dims(input_dims) |
506 | | - converted_transforms.append(transform_tc) |
| 532 | + tc_transform = RandomCrop._from_torchvision(transform) |
| 533 | + output_dims = tc_transform._calculate_output_dims(curr_input_dims) |
| 534 | + converted_transforms.append((tc_transform, curr_input_dims)) |
507 | 535 | else: |
508 | 536 | raise ValueError( |
509 | 537 | f"Unsupported transform: {transform}. Transforms must be " |
510 | 538 | "either a TorchCodec DecoderTransform or a TorchVision " |
511 | 539 | "v2 transform." |
512 | 540 | ) |
513 | | - else: |
514 | | - input_dims = transform._get_output_dims(input_dims) |
515 | | - converted_transforms.append(transform) |
516 | | - |
517 | | - return converted_transforms |
518 | | - |
519 | 541 |
|
520 | | -def _make_transform_specs( |
521 | | - transforms: Optional[Sequence[Union[DecoderTransform, nn.Module]]], |
522 | | - input_dims: Tuple[Optional[int], Optional[int]], |
523 | | - dimension_order: Literal["NCHW", "NHWC"], |
524 | | -) -> str: |
525 | | - """Given a sequence of transforms, turn those into the specification string |
526 | | - the core API expects. |
527 | | -
|
528 | | - Args: |
529 | | - transforms: Optional sequence of transform objects. The objects can be |
530 | | - one of two types: |
531 | | - 1. torchcodec.transforms.DecoderTransform |
532 | | - 2. torchvision.transforms.v2.Transform, but our type annotation |
533 | | - only mentions its base, nn.Module. We don't want to take a |
534 | | - hard dependency on TorchVision. |
535 | | -
|
536 | | - Returns: |
537 | | - String of transforms in the format the core API expects: transform |
538 | | - specifications separate by semicolons. |
539 | | - """ |
540 | | - if transforms is None: |
541 | | - return "" |
| 542 | + curr_input_dims = output_dims |
542 | 543 |
|
543 | | - transforms = _convert_to_decoder_transforms(transforms, input_dims, dimension_order) |
544 | | - return ";".join([t._make_transform_spec() for t in transforms]) |
| 544 | + return ";".join([t._make_transform_spec(dims) for t, dims in converted_transforms]) |
545 | 545 |
|
546 | 546 |
|
547 | 547 | def _read_custom_frame_mappings( |
|
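For readers skimming the diff, the dimension-threading that the long comment block in `_make_transform_specs` describes reduces to the pattern sketched below. The classes, spec strings, and helper name are hypothetical stand-ins, not torchcodec's real ones; only the control flow mirrors the new code:

```python
from typing import Optional, Sequence, Tuple

Dims = Tuple[Optional[int], Optional[int]]


class StubResize:
    """Stand-in transform whose output size never depends on its input size."""

    def __init__(self, height: int, width: int):
        self.height, self.width = height, width

    def _calculate_output_dims(self, input_dims: Dims) -> Dims:
        return (self.height, self.width)

    def _make_transform_spec(self, input_dims: Dims) -> str:
        return f"resize, {self.height}, {self.width}"  # illustrative format only


class StubCrop:
    """Stand-in transform that must know its input dims to validate the crop."""

    def __init__(self, height: int, width: int):
        self.height, self.width = height, width

    def _calculate_output_dims(self, input_dims: Dims) -> Dims:
        in_h, in_w = input_dims
        # Dimensions are asserted only here, by the transform that needs them.
        assert in_h is not None and in_w is not None, "crop needs known input dims"
        assert self.height <= in_h and self.width <= in_w, "crop larger than frame"
        return (self.height, self.width)

    def _make_transform_spec(self, input_dims: Dims) -> str:
        return f"crop, {self.height}, {self.width}"  # illustrative format only


def make_transform_specs(transforms: Sequence, input_dims: Dims) -> str:
    converted = []
    curr_input_dims = input_dims
    for transform in transforms:
        output_dims = transform._calculate_output_dims(curr_input_dims)
        # Store the dims the transform *receives*; its spec may depend on them.
        converted.append((transform, curr_input_dims))
        # The frame produced by this transform is the input to the next one.
        curr_input_dims = output_dims
    return ";".join(t._make_transform_spec(dims) for t, dims in converted)


# The crop is validated against the resized 270x480 frame, not the original
# 1080x1920 video, because the resize's output dims were threaded through.
print(make_transform_specs([StubResize(270, 480), StubCrop(224, 224)], (1080, 1920)))
# -> resize, 270, 480;crop, 224, 224
```

Note how the sketch stores the dimensions each transform receives, mirroring `converted_transforms.append((transform, curr_input_dims))` in the diff, so that spec generation can consult the input dimensions rather than the output dimensions.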