| 1 | +import math |
| 2 | +import numbers |
| 3 | +import random |
| 4 | +import warnings |
| 5 | +from typing import List, Sequence |
| 6 | + |
1 | 7 | import torch |
2 | 8 | import torchvision.transforms.functional as F |
3 | 9 | try: |
4 | 10 | from torchvision.transforms.functional import InterpolationMode |
5 | 11 | has_interpolation_mode = True |
6 | 12 | except ImportError: |
7 | 13 | has_interpolation_mode = False |
8 | 14 | from PIL import Image |
9 | | -import warnings |
10 | | -import math |
11 | | -import random |
12 | 15 | import numpy as np |
13 | 16 | |
14 | 17 | |
@@ -96,6 +99,19 @@ def interp_mode_to_str(mode): |
96 | 99 | _RANDOM_INTERPOLATION = (str_to_interp_mode('bilinear'), str_to_interp_mode('bicubic')) |
97 | 100 | |
98 | 101 | |
| 102 | +def _setup_size(size, error_msg): |
| 103 | + if isinstance(size, numbers.Number): |
| 104 | + return int(size), int(size) |
| 105 | + |
| 106 | + if isinstance(size, Sequence) and len(size) == 1: |
| 107 | + return size[0], size[0] |
| 108 | + |
| 109 | + if len(size) != 2: |
| 110 | + raise ValueError(error_msg) |
| 111 | + |
| 112 | + return size |
| 113 | + |
| 114 | + |
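A quick note on `_setup_size` above: it appears to mirror torchvision's private helper of the same name, collapsing every accepted `size` form into an `(h, w)` pair. A minimal behavior sketch:

```python
# Behavior sketch for the _setup_size helper defined in the diff above.
assert _setup_size(224, "bad size") == (224, 224)          # bare int -> square
assert _setup_size([224], "bad size") == (224, 224)        # length-1 sequence -> square
assert _setup_size((224, 320), "bad size") == (224, 320)   # (h, w) passes through unchanged
_setup_size((1, 2, 3), "bad size")                         # raises ValueError("bad size")
```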
99 | 115 | class RandomResizedCropAndInterpolation: |
100 | 116 | """Crop the given PIL Image to random size and aspect ratio with random interpolation. |
101 | 117 | |
@@ -195,3 +211,132 @@ def __repr__(self): |
195 | 211 | format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio)) |
196 | 212 | format_string += ', interpolation={0})'.format(interpolate_str) |
197 | 213 | return format_string |
| 214 | + |
| 215 | + |
| 216 | +def center_crop_or_pad(img: torch.Tensor, output_size: List[int], fill=0) -> torch.Tensor: |
| 217 | + """Center crops and/or pads the given image. |
| 218 | + If the image is torch Tensor, it is expected |
| 219 | + to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions. |
| 220 | + If image size is smaller than output size along any edge, the image is padded with ``fill`` and then center cropped. |
| 221 | + |
| 222 | + Args: |
| 223 | + img (PIL Image or Tensor): Image to be cropped. |
| 224 | + output_size (sequence or int): (height, width) of the crop box. If int or sequence with single int, |
| 225 | + it is used for both directions. |
| 226 | + fill (int, Tuple[int]): Pixel fill value for the padded area. Default is 0. |
| 227 | + |
| 228 | + Returns: |
| 229 | + PIL Image or Tensor: Cropped image. |
| 230 | + """ |
| 231 | + if isinstance(output_size, numbers.Number): |
| 232 | + output_size = (int(output_size), int(output_size)) |
| 233 | + elif isinstance(output_size, (tuple, list)) and len(output_size) == 1: |
| 234 | + output_size = (output_size[0], output_size[0]) |
| 235 | + |
| 236 | + _, image_height, image_width = F.get_dimensions(img) |
| 237 | + crop_height, crop_width = output_size |
| 238 | + |
| 239 | + if crop_width > image_width or crop_height > image_height: |
| 240 | + padding_ltrb = [ |
| 241 | + (crop_width - image_width) // 2 if crop_width > image_width else 0, |
| 242 | + (crop_height - image_height) // 2 if crop_height > image_height else 0, |
| 243 | + (crop_width - image_width + 1) // 2 if crop_width > image_width else 0, |
| 244 | + (crop_height - image_height + 1) // 2 if crop_height > image_height else 0, |
| 245 | + ] |
| 246 | + img = F.pad(img, padding_ltrb, fill=fill) |
| 247 | + _, image_height, image_width = F.get_dimensions(img) |
| 248 | + if crop_width == image_width and crop_height == image_height: |
| 249 | + return img |
| 250 | + |
| 251 | + crop_top = int(round((image_height - crop_height) / 2.0)) |
| 252 | + crop_left = int(round((image_width - crop_width) / 2.0)) |
| 253 | + return F.crop(img, crop_top, crop_left, crop_height, crop_width) |
| 254 | + |
| 255 | + |
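To make the pad-then-crop logic above concrete, here is a small worked example (assuming a torchvision recent enough to provide `F.get_dimensions`, i.e. >= 0.13). Padding is split as evenly as possible, with the extra pixel of an odd difference going to the right/bottom:

```python
import torch

# 20x30 (HxW) input, 32x32 target: width is short by 2, height by 12.
# padding_ltrb works out to [1, 6, 1, 6], so the padded image is exactly
# 32x32 and the function returns right after padding, without a crop.
img = torch.zeros(3, 20, 30)
out = center_crop_or_pad(img, (32, 32))
print(out.shape)  # torch.Size([3, 32, 32])

# 50x50 input, 32x32 target: no padding needed, plain center crop.
out = center_crop_or_pad(torch.zeros(3, 50, 50), (32, 32))
print(out.shape)  # torch.Size([3, 32, 32])
```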
| 256 | +class CenterCropOrPad(torch.nn.Module): |
| 257 | + """Crops the given image at the center. |
| 258 | + If the image is torch Tensor, it is expected |
| 259 | + to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions. |
| 260 | + If image size is smaller than output size along any edge, the image is padded with ``fill`` and then center cropped. |
| 261 | + |
| 262 | + Args: |
| 263 | + size (sequence or int): Desired output size of the crop. If size is an |
| 264 | + int instead of sequence like (h, w), a square crop (size, size) is |
| 265 | + made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]). |
| 266 | + """ |
| 267 | + |
| 268 | + def __init__(self, size, fill=0): |
| 269 | + super().__init__() |
| 270 | + self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") |
| 271 | + self.fill = fill |
| 272 | + |
| 273 | + def forward(self, img): |
| 274 | + """ |
| 275 | + Args: |
| 276 | + img (PIL Image or Tensor): Image to be cropped. |
| 277 | + |
| 278 | + Returns: |
| 279 | + PIL Image or Tensor: Cropped image. |
| 280 | + """ |
| 281 | + return center_crop_or_pad(img, self.size, fill=self.fill) |
| 282 | + |
| 283 | + def __repr__(self) -> str: |
| 284 | + return f"{self.__class__.__name__}(size={self.size})" |
| 285 | + |
| 286 | + |
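Usage-wise, `CenterCropOrPad` drops into a standard transform pipeline. A hypothetical composition (the 224/fill values here are illustrative, not from the diff):

```python
from torchvision import transforms

pipeline = transforms.Compose([
    transforms.PILToTensor(),      # PIL -> uint8 tensor [C, H, W]
    CenterCropOrPad(224, fill=0),  # pad small images, crop large ones, to 224x224
])
```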
| 287 | +class ResizeKeepRatio: |
| 288 | + """ Resize and Keep Ratio |
| 289 | + """ |
| 290 | + |
| 291 | + def __init__( |
| 292 | + self, |
| 293 | + size, |
| 294 | + longest=0., |
| 295 | + interpolation='bilinear', |
| 296 | + fill=0, |
| 297 | + ): |
| 298 | + if isinstance(size, (list, tuple)): |
| 299 | + self.size = tuple(size) |
| 300 | + else: |
| 301 | + self.size = (size, size) |
| 302 | + self.interpolation = str_to_interp_mode(interpolation) |
| 303 | + self.longest = float(longest) |
| 304 | + self.fill = fill # note: stored but not used by the resize itself |
| 305 | + |
| 306 | + @staticmethod |
| 307 | + def get_params(img, target_size, longest): |
| 308 | + """Get parameters |
| 309 | + |
| 310 | + Args: |
| 311 | + img (PIL Image): Image to be resized. |
| 312 | + target_size (Tuple[int, int]): (height, width) of the output |
| 313 | + Returns: |
| 314 | + list: size (h, w) to be passed to ``resize`` |
| 315 | + """ |
| 316 | + source_size = img.size[::-1] # PIL .size is (w, h); flip to (h, w) |
| 317 | + h, w = source_size |
| 318 | + target_h, target_w = target_size |
| 319 | + ratio_h = h / target_h |
| 320 | + ratio_w = w / target_w |
| 321 | + ratio = max(ratio_h, ratio_w) * longest + min(ratio_h, ratio_w) * (1. - longest) |
| 322 | + size = [round(x / ratio) for x in source_size] |
| 323 | + return size |
| 324 | + |
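The `longest` factor in `get_params` linearly blends the two edge ratios: at `longest=1.0` the relatively longer edge is matched to the target, so the resized image fits inside it (and needs padding to fill), while at `longest=0.0` the shorter edge is matched, so the image covers the target (and needs cropping). A worked example with illustrative numbers:

```python
h, w, target = 400, 600, 224
ratio_h, ratio_w = h / target, w / target   # ~1.786, ~2.679

for longest in (1.0, 0.0):
    ratio = max(ratio_h, ratio_w) * longest + min(ratio_h, ratio_w) * (1. - longest)
    print([round(x / ratio) for x in (h, w)])
# longest=1.0 -> [149, 224]  (fits inside 224x224; pad to fill)
# longest=0.0 -> [224, 336]  (covers 224x224; crop to fit)
```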
| 325 | + def __call__(self, img): |
| 326 | + """ |
| 327 | + Args: |
| 328 | + img (PIL Image): Image to be resized. |
| 329 | + |
| 330 | + Returns: |
| 331 | + PIL Image: Resized image; aspect ratio is preserved, so the result may need a follow-up pad/crop to reach ``size`` exactly |
| 332 | + """ |
| 333 | + size = self.get_params(img, self.size, self.longest) |
| 334 | + img = F.resize(img, size, self.interpolation) |
| 335 | + return img |
| 336 | + |
| 337 | + def __repr__(self): |
| 338 | + interpolate_str = interp_mode_to_str(self.interpolation) |
| 339 | + format_string = self.__class__.__name__ + f'(size={self.size}' |
| 340 | + format_string += f', interpolation={interpolate_str}' |
| 341 | + format_string += f', longest={self.longest:.3f})' |
| 342 | + return format_string |
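Put together, the two new transforms give a fixed-size output without distorting aspect ratio. A hedged end-to-end sketch (note that `ResizeKeepRatio.get_params` reads `img.size`, so it expects PIL input):

```python
from PIL import Image

resize = ResizeKeepRatio(224, longest=1.0, interpolation='bicubic')
to_square = CenterCropOrPad(224, fill=0)

img = Image.new('RGB', (600, 400))  # PIL sizes are (w, h)
out = to_square(resize(img))        # resized to (224, 149), then padded to (224, 224)
print(out.size)                     # (224, 224)
```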