Skip to content

Commit 0e7610e

Browse files
committed
Add crop_to_aspect_ratio arg in keras.ops.image.resize.
1 parent 7dae3e9 commit 0e7610e

File tree

6 files changed

+181
-18
lines changed

6 files changed

+181
-18
lines changed

keras/backend/jax/image.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ def resize(
4242
size,
4343
interpolation="bilinear",
4444
antialias=False,
45+
crop_to_aspect_ratio=False,
4546
data_format="channels_last",
4647
):
4748
if interpolation not in RESIZE_INTERPOLATIONS:
@@ -55,6 +56,7 @@ def resize(
5556
f"(height, width). Received: size={size}"
5657
)
5758
size = tuple(size)
59+
target_height, target_width = size
5860
if len(image.shape) == 4:
5961
if data_format == "channels_last":
6062
size = (image.shape[0],) + size + (image.shape[-1],)
@@ -71,6 +73,48 @@ def resize(
7173
"or rank 4 (batch of images). Received input with shape: "
7274
f"image.shape={image.shape}"
7375
)
76+
77+
if crop_to_aspect_ratio:
78+
shape = image.shape
79+
if data_format == "channels_last":
80+
height, width = shape[-3], shape[-2]
81+
else:
82+
height, width = shape[-2], shape[-1]
83+
crop_height = int(float(width * target_height) / target_width)
84+
crop_height = min(height, crop_height)
85+
crop_width = int(float(height * target_width) / target_height)
86+
crop_width = min(width, crop_width)
87+
crop_box_hstart = int(float(height - crop_height) / 2)
88+
crop_box_wstart = int(float(width - crop_width) / 2)
89+
if data_format == "channels_last":
90+
if len(image.shape) == 4:
91+
image = image[
92+
:,
93+
crop_box_hstart : crop_box_hstart + crop_height,
94+
crop_box_wstart : crop_box_wstart + crop_width,
95+
:,
96+
]
97+
else:
98+
image = image[
99+
crop_box_hstart : crop_box_hstart + crop_height,
100+
crop_box_wstart : crop_box_wstart + crop_width,
101+
:,
102+
]
103+
else:
104+
if len(image.shape) == 4:
105+
image = image[
106+
:,
107+
:,
108+
crop_box_hstart : crop_box_hstart + crop_height,
109+
crop_box_wstart : crop_box_wstart + crop_width,
110+
]
111+
else:
112+
image = image[
113+
:,
114+
crop_box_hstart : crop_box_hstart + crop_height,
115+
crop_box_wstart : crop_box_wstart + crop_width,
116+
]
117+
74118
return jax.image.resize(
75119
image, size, method=interpolation, antialias=antialias
76120
)

keras/backend/numpy/image.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ def resize(
4141
size,
4242
interpolation="bilinear",
4343
antialias=False,
44+
crop_to_aspect_ratio=False,
4445
data_format="channels_last",
4546
):
4647
if interpolation not in RESIZE_INTERPOLATIONS:
@@ -54,6 +55,7 @@ def resize(
5455
f"(height, width). Received: size={size}"
5556
)
5657
size = tuple(size)
58+
target_height, target_width = size
5759
if len(image.shape) == 4:
5860
if data_format == "channels_last":
5961
size = (image.shape[0],) + size + (image.shape[-1],)
@@ -70,6 +72,48 @@ def resize(
7072
"or rank 4 (batch of images). Received input with shape: "
7173
f"image.shape={image.shape}"
7274
)
75+
76+
if crop_to_aspect_ratio:
77+
shape = image.shape
78+
if data_format == "channels_last":
79+
height, width = shape[-3], shape[-2]
80+
else:
81+
height, width = shape[-2], shape[-1]
82+
crop_height = int(float(width * target_height) / target_width)
83+
crop_height = min(height, crop_height)
84+
crop_width = int(float(height * target_width) / target_height)
85+
crop_width = min(width, crop_width)
86+
crop_box_hstart = int(float(height - crop_height) / 2)
87+
crop_box_wstart = int(float(width - crop_width) / 2)
88+
if data_format == "channels_last":
89+
if len(image.shape) == 4:
90+
image = image[
91+
:,
92+
crop_box_hstart : crop_box_hstart + crop_height,
93+
crop_box_wstart : crop_box_wstart + crop_width,
94+
:,
95+
]
96+
else:
97+
image = image[
98+
crop_box_hstart : crop_box_hstart + crop_height,
99+
crop_box_wstart : crop_box_wstart + crop_width,
100+
:,
101+
]
102+
else:
103+
if len(image.shape) == 4:
104+
image = image[
105+
:,
106+
:,
107+
crop_box_hstart : crop_box_hstart + crop_height,
108+
crop_box_wstart : crop_box_wstart + crop_width,
109+
]
110+
else:
111+
image = image[
112+
:,
113+
crop_box_hstart : crop_box_hstart + crop_height,
114+
crop_box_wstart : crop_box_wstart + crop_width,
115+
]
116+
73117
return np.array(
74118
jax.image.resize(image, size, method=interpolation, antialias=antialias)
75119
)

keras/backend/tensorflow/image.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ def resize(
4141
size,
4242
interpolation="bilinear",
4343
antialias=False,
44+
crop_to_aspect_ratio=False,
4445
data_format="channels_last",
4546
):
4647
if interpolation not in RESIZE_INTERPOLATIONS:
@@ -65,6 +66,42 @@ def resize(
6566
"or rank 4 (batch of images). Received input with shape: "
6667
f"image.shape={image.shape}"
6768
)
69+
if crop_to_aspect_ratio:
70+
shape = tf.shape(image)
71+
height, width = shape[-3], shape[-2]
72+
target_height, target_width = size
73+
crop_height = tf.cast(
74+
tf.cast(width * target_height, "float32") / target_width,
75+
"int32",
76+
)
77+
crop_height = tf.minimum(height, crop_height)
78+
crop_height = tf.cast(crop_height, "int32")
79+
crop_width = tf.cast(
80+
tf.cast(height * target_width, "float32") / target_height,
81+
"int32",
82+
)
83+
crop_width = tf.minimum(width, crop_width)
84+
crop_width = tf.cast(crop_width, "int32")
85+
86+
crop_box_hstart = tf.cast(
87+
tf.cast(height - crop_height, "float32") / 2, "int32"
88+
)
89+
crop_box_wstart = tf.cast(
90+
tf.cast(width - crop_width, "float32") / 2, "int32"
91+
)
92+
if len(image.shape) == 4:
93+
image = image[
94+
:,
95+
crop_box_hstart : crop_box_hstart + crop_height,
96+
crop_box_wstart : crop_box_wstart + crop_width,
97+
:,
98+
]
99+
else:
100+
image = image[
101+
crop_box_hstart : crop_box_hstart + crop_height,
102+
crop_box_wstart : crop_box_wstart + crop_width,
103+
:,
104+
]
68105

69106
resized = tf.image.resize(
70107
image, size, method=interpolation, antialias=antialias

keras/backend/torch/image.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ def resize(
5151
size,
5252
interpolation="bilinear",
5353
antialias=False,
54+
crop_to_aspect_ratio=False,
5455
data_format="channels_last",
5556
):
5657
try:
@@ -99,6 +100,30 @@ def resize(
99100
f"image.shape={image.shape}"
100101
)
101102

103+
if crop_to_aspect_ratio:
104+
shape = image.shape
105+
height, width = shape[-2], shape[-1]
106+
target_height, target_width = size
107+
crop_height = int(float(width * target_height) / target_width)
108+
crop_height = min(height, crop_height)
109+
crop_width = int(float(height * target_width) / target_height)
110+
crop_width = min(width, crop_width)
111+
crop_box_hstart = int(float(height - crop_height) / 2)
112+
crop_box_wstart = int(float(width - crop_width) / 2)
113+
if len(image.shape) == 4:
114+
image = image[
115+
:,
116+
:,
117+
crop_box_hstart : crop_box_hstart + crop_height,
118+
crop_box_wstart : crop_box_wstart + crop_width,
119+
]
120+
else:
121+
image = image[
122+
:,
123+
crop_box_hstart : crop_box_hstart + crop_height,
124+
crop_box_wstart : crop_box_wstart + crop_width,
125+
]
126+
102127
resized = torchvision.transforms.functional.resize(
103128
img=image,
104129
size=size,

keras/layers/preprocessing/resizing.py

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from keras import backend
22
from keras.api_export import keras_export
33
from keras.layers.preprocessing.tf_data_layer import TFDataLayer
4-
from keras.utils import image_utils
54

65

76
@keras_export("keras.layers.Resizing")
@@ -69,22 +68,13 @@ def __init__(
6968

7069
def call(self, inputs):
7170
size = (self.height, self.width)
72-
if self.crop_to_aspect_ratio:
73-
outputs = image_utils.smart_resize(
74-
inputs,
75-
size=size,
76-
interpolation=self.interpolation,
77-
data_format=self.data_format,
78-
backend_module=self.backend,
79-
)
80-
else:
81-
outputs = self.backend.image.resize(
82-
inputs,
83-
size=size,
84-
interpolation=self.interpolation,
85-
data_format=self.data_format,
86-
)
87-
return outputs
71+
return self.backend.image.resize(
72+
inputs,
73+
size=size,
74+
interpolation=self.interpolation,
75+
data_format=self.data_format,
76+
crop_to_aspect_ratio=self.crop_to_aspect_ratio,
77+
)
8878

8979
def compute_output_shape(self, input_shape):
9080
input_shape = list(input_shape)

keras/ops/image.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,13 +110,15 @@ def __init__(
110110
size,
111111
interpolation="bilinear",
112112
antialias=False,
113+
crop_to_aspect_ratio=False,
113114
data_format="channels_last",
114115
):
115116
super().__init__()
116117
self.size = tuple(size)
117118
self.interpolation = interpolation
118119
self.antialias = antialias
119120
self.data_format = data_format
121+
self.crop_to_aspect_ratio = crop_to_aspect_ratio
120122

121123
def call(self, image):
122124
return backend.image.resize(
@@ -125,6 +127,7 @@ def call(self, image):
125127
interpolation=self.interpolation,
126128
antialias=self.antialias,
127129
data_format=self.data_format,
130+
crop_to_aspect_ratio=self.crop_to_aspect_ratio,
128131
)
129132

130133
def compute_output_spec(self, image):
@@ -156,6 +159,7 @@ def resize(
156159
size,
157160
interpolation="bilinear",
158161
antialias=False,
162+
crop_to_aspect_ratio=False,
159163
data_format="channels_last",
160164
):
161165
"""Resize images to size using the specified interpolation method.
@@ -167,6 +171,13 @@ def resize(
167171
`"bilinear"`, and `"bicubic"`. Defaults to `"bilinear"`.
168172
antialias: Whether to use an antialiasing filter when downsampling an
169173
image. Defaults to `False`.
174+
crop_to_aspect_ratio: If `True`, resize the images without aspect
175+
ratio distortion. When the original aspect ratio differs
176+
from the target aspect ratio, the input image is first
177+
center-cropped to the
178+
largest possible window that matches the target aspect ratio,
179+
and that window is then resized to `size`. By default
180+
(`crop_to_aspect_ratio=False`), aspect ratio may not be preserved.
170181
data_format: string, either `"channels_last"` or `"channels_first"`.
171182
The ordering of the dimensions in the inputs. `"channels_last"`
172183
corresponds to inputs with shape `(batch, height, width, channels)`
@@ -197,19 +208,31 @@ def resize(
197208
>>> y.shape
198209
(2, 3, 2, 2)
199210
"""
200-
211+
if len(size) != 2:
212+
raise ValueError(
213+
"Expected `size` to be a tuple of 2 integers. "
214+
f"Received: size={size}"
215+
)
216+
if len(image.shape) < 3 or len(image.shape) > 4:
217+
raise ValueError(
218+
"Expected an image array with shape `(height, width, "
219+
"channels)`, or `(batch_size, height, width, channels)`, but "
220+
f"got input with incorrect rank, of shape {image.shape}."
221+
)
201222
if any_symbolic_tensors((image,)):
202223
return Resize(
203224
size,
204225
interpolation=interpolation,
205226
antialias=antialias,
206227
data_format=data_format,
228+
crop_to_aspect_ratio=crop_to_aspect_ratio,
207229
).symbolic_call(image)
208230
return backend.image.resize(
209231
image,
210232
size,
211233
interpolation=interpolation,
212234
antialias=antialias,
235+
crop_to_aspect_ratio=crop_to_aspect_ratio,
213236
data_format=data_format,
214237
)
215238

0 commit comments

Comments
 (0)