|
15 | 15 |
|
16 | 16 | import functools |
17 | 17 | from dataclasses import dataclass |
| 18 | +from typing import Any, Iterable, List, Optional, Tuple |
18 | 19 |
|
19 | 20 | from .image_processing_utils import BaseImageProcessor |
20 | | -from .utils.import_utils import is_torchvision_available |
| 21 | +from .utils.import_utils import is_torch_available, is_torchvision_available |
21 | 22 |
|
22 | 23 |
|
23 | 24 | if is_torchvision_available(): |
24 | 25 | from torchvision.transforms import Compose |
25 | 26 |
|
| 27 | +if is_torch_available(): |
| 28 | + import torch |
| 29 | + |
26 | 30 |
|
27 | 31 | @dataclass(frozen=True) |
28 | 32 | class SizeDict: |
@@ -66,3 +70,64 @@ def to_dict(self): |
66 | 70 | encoder_dict = super().to_dict() |
67 | 71 | encoder_dict.pop("_transform_params", None) |
68 | 72 | return encoder_dict |
| 73 | + |
| 74 | + |
def get_image_size_for_max_height_width(
    image_size: Tuple[int, int],
    max_height: int,
    max_width: int,
) -> Tuple[int, int]:
    """
    Computes the output image size given the input image size and the maximum allowed height and width, keeping the
    aspect ratio.

    Important: even if image_height < max_height and image_width < max_width, the image will be resized so that at
    least one of the edges is equal to max_height or max_width.

    For example:
        - input_size: (100, 200), max_height: 50, max_width: 50 -> output_size: (25, 50)
        - input_size: (100, 200), max_height: 200, max_width: 500 -> output_size: (200, 400)

    Args:
        image_size (`Tuple[int, int]`):
            The `(height, width)` of the image to resize.
        max_height (`int`):
            The maximum allowed height.
        max_width (`int`):
            The maximum allowed width.

    Returns:
        `Tuple[int, int]`: The `(new_height, new_width)` of the resized image. The edge that limits the scaling is
        exactly equal to its maximum; the other edge is rounded down.

    Raises:
        ZeroDivisionError: If the input height or width is 0.
    """
    height, width = image_size
    height_scale = max_height / height
    width_scale = max_width / width

    # The limiting edge is pinned to its maximum and the other edge is derived with integer arithmetic. Computing
    # `int(edge * scale)` instead can lose a pixel to floating point rounding: 49 * (1 / 49) evaluates to
    # 0.9999999999999999, which truncates to 0 rather than 1.
    if height_scale <= width_scale:
        return int(max_height), int(width * max_height // height)
    return int(height * max_width // width), int(max_width)
| 104 | + |
| 105 | + |
def safe_squeeze(tensor: "torch.Tensor", axis: Optional[int] = None) -> "torch.Tensor":
    """
    Squeezes a tensor without failing when the requested axis cannot be squeezed.

    Args:
        tensor (`torch.Tensor`):
            The tensor to squeeze.
        axis (`int`, *optional*):
            The axis to squeeze. If `None`, `tensor.squeeze()` is called without arguments.

    Returns:
        `torch.Tensor`: The result of `tensor.squeeze(axis=axis)`, or `tensor` itself if that call raises a
        `ValueError`.
    """
    if axis is not None:
        try:
            squeezed = tensor.squeeze(axis=axis)
        except ValueError:
            # The axis could not be squeezed: hand the input back untouched.
            squeezed = tensor
        return squeezed

    return tensor.squeeze()
| 117 | + |
| 118 | + |
def max_across_indices(values: Iterable[Any]) -> List[Any]:
    """
    Return the element-wise maximum across an iterable of sequences.

    The sequences are aligned position by position with `zip`, so the result has one entry per position and is
    as long as the shortest sequence.
    """
    aligned = zip(*values)
    return list(map(max, aligned))
| 124 | + |
| 125 | + |
def get_max_height_width(images: List["torch.Tensor"]) -> Tuple[int, int]:
    """
    Get the maximum height and width across all images in a batch.

    Args:
        images (`List[torch.Tensor]`):
            The images to inspect. The element-wise maximum of their shapes must have exactly three entries: the
            first is ignored and the last two are read as height and width (presumably a channels-first layout;
            confirm against callers).

    Returns:
        `Tuple[int, int]`: The `(max_height, max_width)` found across `images`.
    """
    _, max_height, max_width = max_across_indices([img.shape for img in images])
    return (max_height, max_width)
0 commit comments