
Commit f1b6a50

Merge pull request #2138 from f4str/object-detector-resize
Implement Square Pad and Resize Preprocessors
2 parents 43d13fc + 7a6eac3 commit f1b6a50

14 files changed: +1438 additions, −2 deletions

art/defences/preprocessor/preprocessor.py

Lines changed: 2 additions & 2 deletions

@@ -80,7 +80,7 @@ def apply_predict(self) -> bool:
         return self._apply_predict

     @abc.abstractmethod
-    def __call__(self, x: np.ndarray, y: Optional[np.ndarray] = None) -> Tuple[np.ndarray, Optional[np.ndarray]]:
+    def __call__(self, x: np.ndarray, y: Optional[Any] = None) -> Tuple[np.ndarray, Optional[Any]]:
         """
         Perform data preprocessing and return preprocessed data as tuple.

@@ -250,7 +250,7 @@ class PreprocessorTensorFlowV2(Preprocessor):
     """

     @abc.abstractmethod
-    def forward(self, x: "tf.Tensor", y: Optional["tf.Tensor"] = None) -> Tuple["tf.Tensor", Optional[Any]]:
+    def forward(self, x: "tf.Tensor", y: Optional[Any] = None) -> Tuple["tf.Tensor", Optional[Any]]:
         """
         Perform data preprocessing in TensorFlow v2 and return preprocessed data as tuple.
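The change above widens the label argument of the abstract `__call__` and `forward` from an array or tensor to `Optional[Any]`, so subclasses can accept object detection labels (a list of per-image dicts holding `boxes`, `labels`, etc.) instead of only classification arrays. A minimal sketch of a subclass relying on the widened annotation, assuming `__call__` is the only abstract method; the class itself is illustrative and not part of this commit:

from typing import Any, Optional, Tuple

import numpy as np

from art.defences.preprocessor.preprocessor import Preprocessor


class IdentityPreprocessor(Preprocessor):  # hypothetical example, not in ART
    """Pass inputs and labels through unchanged."""

    def __call__(self, x: np.ndarray, y: Optional[Any] = None) -> Tuple[np.ndarray, Optional[Any]]:
        # `y` may now be a list of dicts such as [{"boxes": ..., "labels": ...}]
        # rather than a single np.ndarray, thanks to the widened annotation.
        return x, y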
art/preprocessing/image/__init__.py

Lines changed: 9 additions & 0 deletions

@@ -0,0 +1,9 @@
"""
This module contains image preprocessing tools.
"""
from art.preprocessing.image.image_resize.numpy import ImageResize
from art.preprocessing.image.image_resize.pytorch import ImageResizePyTorch
from art.preprocessing.image.image_resize.tensorflow import ImageResizeTensorFlowV2
from art.preprocessing.image.image_square_pad.numpy import ImageSquarePad
from art.preprocessing.image.image_square_pad.pytorch import ImageSquarePadPyTorch
from art.preprocessing.image.image_square_pad.tensorflow import ImageSquarePadTensorFlowV2
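With these re-exports, the six new preprocessors become importable from the package root. A minimal sketch of chaining the two defences for an object detector, matching the PR title; `ImageSquarePad`'s constructor is not shown on this page, so passing it `label_type` is an assumption:

from art.preprocessing.image import ImageResize, ImageSquarePad

# Pad to a square first, then resize to the detector's input size.
defences = [
    ImageSquarePad(label_type="object_detection"),  # signature assumed
    ImageResize(height=416, width=416, label_type="object_detection"),
]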

art/preprocessing/image/image_resize/__init__.py

Whitespace-only changes.
art/preprocessing/image/image_resize/numpy.py

Lines changed: 154 additions & 0 deletions

@@ -0,0 +1,154 @@
# MIT License
#
# Copyright (C) The Adversarial Robustness Toolbox (ART) Authors 2023
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit
# persons to whom the Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the
# Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
# WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""
This module implements resizing for images and object detection bounding boxes.
"""
import logging
from typing import Dict, List, Optional, TYPE_CHECKING, Tuple, Union

import numpy as np
import cv2
from tqdm.auto import tqdm

from art.preprocessing.preprocessing import Preprocessor

if TYPE_CHECKING:
    from art.utils import CLIP_VALUES_TYPE

logger = logging.getLogger(__name__)


class ImageResize(Preprocessor):
    """
    This module implements resizing for images and object detection bounding boxes.
    """

    params = ["height", "width", "channels_first", "label_type", "interpolation", "clip_values", "verbose"]

    label_types = ["classification", "object_detection"]

    def __init__(
        self,
        height: int,
        width: int,
        channels_first: bool = False,
        label_type: str = "classification",
        interpolation: int = cv2.INTER_LINEAR,
        clip_values: Optional["CLIP_VALUES_TYPE"] = None,
        apply_fit: bool = True,
        apply_predict: bool = False,
        verbose: bool = False,
    ):
        """
        Create an instance of ImageResize.

        :param height: The height of the resized image.
        :param width: The width of the resized image.
        :param channels_first: Set channels first or last.
        :param label_type: String defining the label type. Currently supported: `classification`, `object_detection`
        :param interpolation: The desired method to resize the image defined by the `cv2::InterpolationFlags` enum.
        :param clip_values: Tuple of the form `(min, max)` representing the minimum and maximum values allowed
               for features.
        :param apply_fit: True if applied during fitting/training.
        :param apply_predict: True if applied during predicting.
        :param verbose: Show progress bars.
        """
        super().__init__(is_fitted=True, apply_fit=apply_fit, apply_predict=apply_predict)
        self.height = height
        self.width = width
        self.channels_first = channels_first
        self.label_type = label_type
        self.interpolation = interpolation
        self.clip_values = clip_values
        self.verbose = verbose
        self._check_params()

    def __call__(
        self, x: np.ndarray, y: Optional[Union[np.ndarray, List[Dict[str, np.ndarray]]]] = None
    ) -> Tuple[np.ndarray, Optional[Union[np.ndarray, List[Dict[str, np.ndarray]]]]]:
        """
        Resize `x` and adjust bounding boxes for labels `y` accordingly.

        :param x: Input samples. A list of samples is also supported.
        :param y: Label of the samples `x`.
        :return: Transformed samples and labels.
        """
        x_preprocess_list = []
        y_preprocess: Optional[Union[np.ndarray, List[Dict[str, np.ndarray]]]]
        if y is not None and self.label_type == "object_detection":
            y_preprocess = []
        else:
            y_preprocess = y

        for i, x_i in enumerate(tqdm(x, desc="ImageResize", disable=not self.verbose)):
            if self.channels_first:
                x_i = np.transpose(x_i, (1, 2, 0))

            # Resize image: OpenCV swaps height and width
            x_resized = cv2.resize(x_i, (self.width, self.height), interpolation=self.interpolation)

            if self.channels_first:
                x_resized = np.transpose(x_resized, (2, 0, 1))

            x_preprocess_list.append(x_resized)

            if self.label_type == "object_detection" and y is not None:
                y_resized: Dict[str, np.ndarray] = {}

                # Copy labels and ensure types
                if isinstance(y, list) and isinstance(y_preprocess, list):
                    y_i = y[i]
                    if isinstance(y_i, dict):
                        y_resized = {k: np.copy(v) for k, v in y_i.items()}
                    else:
                        raise TypeError("Wrong type for `y` and label_type=object_detection.")
                else:
                    raise TypeError("Wrong type for `y` and label_type=object_detection.")

                # Calculate scaling factor
                height, width, _ = x_i.shape
                height_scale = self.height / height
                width_scale = self.width / width

                # Resize bounding boxes
                y_resized["boxes"][:, 0] *= width_scale
                y_resized["boxes"][:, 1] *= height_scale
                y_resized["boxes"][:, 2] *= width_scale
                y_resized["boxes"][:, 3] *= height_scale

                y_preprocess.append(y_resized)

        x_preprocess = np.stack(x_preprocess_list)
        if self.clip_values is not None:
            x_preprocess = np.clip(x_preprocess, self.clip_values[0], self.clip_values[1])

        return x_preprocess, y_preprocess

    def _check_params(self) -> None:
        if self.height <= 0:
            raise ValueError("The desired image height must be positive.")

        if self.width <= 0:
            raise ValueError("The desired image width must be positive")

        if self.clip_values is not None:
            if len(self.clip_values) != 2:
                raise ValueError("`clip_values` should be a tuple of 2 floats containing the allowed data range.")

            if self.clip_values[0] >= self.clip_values[1]:
                raise ValueError("Invalid `clip_values`: min >= max.")
art/preprocessing/image/image_resize/pytorch.py

Lines changed: 161 additions & 0 deletions

@@ -0,0 +1,161 @@
# MIT License
#
# Copyright (C) The Adversarial Robustness Toolbox (ART) Authors 2023
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit
# persons to whom the Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the
# Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
# WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""
This module implements resizing for images and object detection bounding boxes in PyTorch.
"""
import logging
from typing import Dict, List, Optional, TYPE_CHECKING, Tuple, Union

from tqdm.auto import tqdm

from art.preprocessing.preprocessing import PreprocessorPyTorch

if TYPE_CHECKING:
    # pylint: disable=C0412
    import torch
    from art.utils import CLIP_VALUES_TYPE

logger = logging.getLogger(__name__)


class ImageResizePyTorch(PreprocessorPyTorch):
    """
    This module implements resizing for images and object detection bounding boxes in PyTorch.
    """

    params = ["height", "width", "channels_first", "label_type", "interpolation", "clip_values", "verbose"]

    label_types = ["classification", "object_detection"]

    def __init__(
        self,
        height: int,
        width: int,
        channels_first: bool = True,
        label_type: str = "classification",
        interpolation: str = "bilinear",
        clip_values: Optional["CLIP_VALUES_TYPE"] = None,
        apply_fit: bool = True,
        apply_predict: bool = False,
        verbose: bool = False,
    ):
        """
        Create an instance of ImageResizePyTorch.

        :param height: The height of the resized image.
        :param width: The width of the resized image.
        :param channels_first: Set channels first or last.
        :param label_type: String defining the label type. Currently supported: `classification`, `object_detection`
        :param interpolation: String defining the resizing method. Currently supported: `nearest`, `linear`,
               `bilinear`, `bicubic`, `trilinear`, `area`, `nearest-exact`
        :param clip_values: Tuple of the form `(min, max)` representing the minimum and maximum values allowed
               for features.
        :param apply_fit: True if applied during fitting/training.
        :param apply_predict: True if applied during predicting.
        :param verbose: Show progress bars.
        """
        super().__init__(is_fitted=True, apply_fit=apply_fit, apply_predict=apply_predict)
        self.height = height
        self.width = width
        self.channels_first = channels_first
        self.label_type = label_type
        self.interpolation = interpolation
        self.clip_values = clip_values
        self.verbose = verbose
        self._check_params()

    def forward(
        self,
        x: "torch.Tensor",
        y: Optional[Union["torch.Tensor", List[Dict[str, "torch.Tensor"]]]] = None,
    ) -> Tuple["torch.Tensor", Optional[Union["torch.Tensor", List[Dict[str, "torch.Tensor"]]]]]:
        """
        Resize `x` and adjust bounding boxes for labels `y` accordingly.

        :param x: Input samples. A list of samples is also supported.
        :param y: Label of the samples `x`.
        :return: Transformed samples and labels.
        """
        import torch

        x_preprocess_list = []
        y_preprocess: Optional[Union[torch.Tensor, List[Dict[str, torch.Tensor]]]]
        if y is not None and self.label_type == "object_detection":
            y_preprocess = []
        else:
            y_preprocess = y

        for i, x_i in enumerate(tqdm(x, desc="ImageResizePyTorch", disable=not self.verbose)):
            if not self.channels_first:
                x_i = torch.permute(x_i, (2, 0, 1))

            # Resize image: requires a batch so create batch of size 1
            x_resized = torch.nn.functional.interpolate(
                x_i.unsqueeze(0), size=(self.height, self.width), mode=self.interpolation
            ).squeeze()

            if not self.channels_first:
                x_resized = torch.permute(x_resized, (1, 2, 0))

            x_preprocess_list.append(x_resized)

            if self.label_type == "object_detection" and y is not None:
                y_resized: Dict[str, torch.Tensor] = {}

                # Copy labels and ensure types
                if isinstance(y, list) and isinstance(y_preprocess, list):
                    y_i = y[i]
                    if isinstance(y_i, dict):
                        y_resized = {k: torch.clone(v) for k, v in y_i.items()}
                    else:
                        raise TypeError("Wrong type for `y` and label_type=object_detection.")
                else:
                    raise TypeError("Wrong type for `y` and label_type=object_detection.")

                # Calculate scaling factor
                _, height, width = x_i.shape
                height_scale = self.height / height
                width_scale = self.width / width

                # Resize bounding boxes
                y_resized["boxes"][:, 0] *= width_scale
                y_resized["boxes"][:, 1] *= height_scale
                y_resized["boxes"][:, 2] *= width_scale
                y_resized["boxes"][:, 3] *= height_scale

                y_preprocess.append(y_resized)

        x_preprocess = torch.stack(x_preprocess_list)
        if self.clip_values is not None:
            x_preprocess = torch.clamp(x_preprocess, self.clip_values[0], self.clip_values[1])  # type: ignore

        return x_preprocess, y_preprocess

    def _check_params(self) -> None:
        if self.height <= 0:
            raise ValueError("The desired image height must be positive.")

        if self.width <= 0:
            raise ValueError("The desired image width must be positive")

        if self.clip_values is not None:
            if len(self.clip_values) != 2:
                raise ValueError("`clip_values` should be a tuple of 2 floats containing the allowed data range.")

            if self.clip_values[0] >= self.clip_values[1]:
                raise ValueError("Invalid `clip_values`: min >= max.")
