Skip to content

Commit be5837d

Browse files
authored
Add RandomTranslation to KerasCV (#1314)
* Add RandomTranslation to KerasCV Let's keep this PR lightweight for now, then we can tidy the layer to fit standards later. Right now we are broken per Keras' latest release. * Add copyright * Lint fix * Fix mixed precision test
1 parent 37cd993 commit be5837d

File tree

4 files changed

+470
-1
lines changed

4 files changed

+470
-1
lines changed

keras_cv/layers/preprocessing/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717

1818
from tensorflow.keras.layers import CenterCrop
1919
from tensorflow.keras.layers import RandomHeight
20-
from tensorflow.keras.layers import RandomTranslation
2120
from tensorflow.keras.layers import RandomWidth
2221

2322
from keras_cv.layers.preprocessing.aug_mix import AugMix
@@ -59,6 +58,7 @@
5958
from keras_cv.layers.preprocessing.random_saturation import RandomSaturation
6059
from keras_cv.layers.preprocessing.random_sharpness import RandomSharpness
6160
from keras_cv.layers.preprocessing.random_shear import RandomShear
61+
from keras_cv.layers.preprocessing.random_translation import RandomTranslation
6262
from keras_cv.layers.preprocessing.random_zoom import RandomZoom
6363
from keras_cv.layers.preprocessing.randomly_zoomed_crop import RandomlyZoomedCrop
6464
from keras_cv.layers.preprocessing.repeated_augmentation import RepeatedAugmentation
Lines changed: 253 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,253 @@
1+
# Copyright 2023 The KerasCV Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import tensorflow as tf
16+
from keras import backend
17+
18+
from keras_cv.layers.preprocessing.base_image_augmentation_layer import (
19+
BaseImageAugmentationLayer,
20+
)
21+
from keras_cv.utils import preprocessing
22+
23+
24+
def check_fill_mode_and_interpolation(fill_mode, interpolation):
    """Validate `fill_mode` and `interpolation` arguments.

    Raises `NotImplementedError` when either value is not one of the
    modes supported by the underlying transform op; returns `None`
    otherwise.
    """
    valid_fill_modes = {"reflect", "wrap", "constant", "nearest"}
    valid_interpolations = {"nearest", "bilinear"}
    if fill_mode not in valid_fill_modes:
        raise NotImplementedError(
            f"Unknown `fill_mode` {fill_mode}. Only `reflect`, `wrap`, "
            "`constant` and `nearest` are supported."
        )
    if interpolation not in valid_interpolations:
        raise NotImplementedError(
            f"Unknown `interpolation` {interpolation}. Only `nearest` and "
            "`bilinear` are supported."
        )
35+
36+
37+
def get_translation_matrix(translations, name=None):
    """Returns projective transform(s) for the given translation(s).

    Args:
        translations: A matrix of 2-element lists representing `[dx, dy]`
            to translate for each image (for a batch of images).
        name: The name of the op.

    Returns:
        A tensor of shape `(num_images, 8)` projective transforms which can
        be given to `transform`.
    """
    with backend.name_scope(name or "translation_matrix"):
        batch_size = tf.shape(translations)[0]
        # Each row is the flattened 8-element form of the 3x3 matrix
        #   [[1 0 -dx]
        #    [0 1 -dy]
        #    [0 0   1]]
        # where the final entry is implicit. Transforms are always float32.
        ones = tf.ones((batch_size, 1), tf.float32)
        zeros = tf.zeros((batch_size, 1), tf.float32)
        neg_dx = -translations[:, 0, None]
        neg_dy = -translations[:, 1, None]
        columns = [ones, zeros, neg_dx, zeros, ones, neg_dy, zeros, zeros]
        return tf.concat(columns, axis=1)
69+
70+
71+
# Negative axis indices for the height and width dimensions of an image
# tensor in `(..., height, width, channels)` ("channels_last") layout;
# negative indexing works for both batched (4D) and unbatched (3D) inputs.
H_AXIS = -3
W_AXIS = -2
73+
74+
75+
class RandomTranslation(BaseImageAugmentationLayer):
    """A preprocessing layer which randomly translates images during training.

    This layer will apply random translations to each image during training,
    filling empty space according to `fill_mode`.

    Input pixel values can be of any range (e.g. `[0., 1.)` or `[0, 255]`) and
    of integer or floating point dtype. By default, the layer will output
    floats.

    Args:
        height_factor: a float represented as fraction of value, or a tuple
            of size 2 representing lower and upper bound for shifting
            vertically. A negative value means shifting image up, while a
            positive value means shifting image down. When represented as a
            single positive float, this value is used for both the upper and
            lower bound. For instance, `height_factor=(-0.2, 0.3)` results in
            an output shifted by a random amount in the range `[-20%, +30%]`.
            `height_factor=0.2` results in an output height shifted by a
            random amount in the range `[-20%, +20%]`.
        width_factor: a float represented as fraction of value, or a tuple of
            size 2 representing lower and upper bound for shifting
            horizontally. A negative value means shifting image left, while a
            positive value means shifting image right. When represented as a
            single positive float, this value is used for both the upper and
            lower bound. For instance, `width_factor=(-0.2, 0.3)` results in
            an output shifted left by 20%, and shifted right by 30%.
            `width_factor=0.2` results in an output height shifted left or
            right by 20%.
        fill_mode: Points outside the boundaries of the input are filled
            according to the given mode
            (one of `{"constant", "reflect", "wrap", "nearest"}`).
            - *reflect*: `(d c b a | a b c d | d c b a)` The input is
              extended by reflecting about the edge of the last pixel.
            - *constant*: `(k k k k | a b c d | k k k k)` The input is
              extended by filling all values beyond the edge with the same
              constant value k = 0.
            - *wrap*: `(a b c d | a b c d | a b c d)` The input is extended
              by wrapping around to the opposite edge.
            - *nearest*: `(a a a a | a b c d | d d d d)` The input is
              extended by the nearest pixel.
        interpolation: Interpolation mode. Supported values: `"nearest"`,
            `"bilinear"`.
        seed: Integer. Used to create a random seed.
        fill_value: a float represents the value to be filled outside the
            boundaries when `fill_mode="constant"`.

    Input shape:
        3D (unbatched) or 4D (batched) tensor with shape:
        `(..., height, width, channels)`, in `"channels_last"` format.

    Output shape:
        3D (unbatched) or 4D (batched) tensor with shape:
        `(..., height, width, channels)`, in `"channels_last"` format.
    """

    def __init__(
        self,
        height_factor,
        width_factor,
        fill_mode="reflect",
        interpolation="bilinear",
        seed=None,
        fill_value=0.0,
        **kwargs,
    ):
        super().__init__(seed=seed, force_generator=True, **kwargs)
        self.height_factor = height_factor
        self.height_lower, self.height_upper = self._resolve_factor_bounds(
            height_factor, "height_factor"
        )
        self.width_factor = width_factor
        self.width_lower, self.width_upper = self._resolve_factor_bounds(
            width_factor, "width_factor"
        )
        check_fill_mode_and_interpolation(fill_mode, interpolation)
        self.fill_mode = fill_mode
        self.fill_value = fill_value
        self.interpolation = interpolation
        self.seed = seed

    @staticmethod
    def _resolve_factor_bounds(factor, name):
        """Normalize a scalar-or-pair translation factor to `(lower, upper)`.

        A scalar `f` is expanded to the symmetric range `(-f, f)`.

        Args:
            factor: float, or 2-element tuple/list of floats.
            name: argument name used in error messages (e.g.
                `"height_factor"`).

        Returns:
            A `(lower, upper)` tuple of floats.

        Raises:
            ValueError: if the bounds are inverted, or either bound lies
                outside `[-1, 1]`.
        """
        if isinstance(factor, (tuple, list)):
            lower, upper = factor[0], factor[1]
        else:
            lower, upper = -factor, factor
        if upper < lower:
            raise ValueError(
                f"`{name}` cannot have upper bound less than "
                f"lower bound, got {factor}"
            )
        if abs(lower) > 1.0 or abs(upper) > 1.0:
            raise ValueError(
                f"`{name}` must have values between [-1, 1], "
                f"got {factor}"
            )
        return lower, upper

    def augment_image(self, image, transformation, **kwargs):
        """Translated inputs with random ops."""
        # The transform op only accepts rank 4 inputs, so if we have an
        # unbatched image, we need to temporarily expand dims to a batch.
        original_shape = image.shape
        inputs = tf.expand_dims(image, 0)

        inputs_shape = tf.shape(inputs)
        img_hd = tf.cast(inputs_shape[H_AXIS], tf.float32)
        img_wd = tf.cast(inputs_shape[W_AXIS], tf.float32)
        # Transformation holds fractional shifts; scale them to pixels.
        height_translation = transformation["height_translation"] * img_hd
        width_translation = transformation["width_translation"] * img_wd
        # `get_translation_matrix` expects `[dx, dy]` ordering, i.e.
        # width shift first.
        translations = tf.cast(
            tf.concat([width_translation, height_translation], axis=1),
            dtype=tf.float32,
        )
        output = preprocessing.transform(
            inputs,
            get_translation_matrix(translations),
            interpolation=self.interpolation,
            fill_mode=self.fill_mode,
            fill_value=self.fill_value,
        )

        output = tf.squeeze(output, 0)
        output.set_shape(original_shape)
        return output

    def get_random_transformation(self, image=None, **kwargs):
        """Sample fractional height/width shifts for a single image."""
        batch_size = 1
        height_translation = self._random_generator.random_uniform(
            shape=[batch_size, 1],
            minval=self.height_lower,
            maxval=self.height_upper,
            dtype=tf.float32,
        )
        width_translation = self._random_generator.random_uniform(
            shape=[batch_size, 1],
            minval=self.width_lower,
            maxval=self.width_upper,
            dtype=tf.float32,
        )
        return {
            "height_translation": height_translation,
            "width_translation": width_translation,
        }

    def _batch_augment(self, inputs):
        # Change to vectorized_map for better performance, as well as work
        # around issue for different tensorspec between inputs and outputs.
        return tf.vectorized_map(self._augment, inputs)

    def augment_label(self, label, transformation, **kwargs):
        # Translation does not alter labels.
        return label

    def compute_output_shape(self, input_shape):
        # Pixels move within the frame; the tensor shape is unchanged.
        return input_shape

    def get_config(self):
        config = {
            "height_factor": self.height_factor,
            "width_factor": self.width_factor,
            "fill_mode": self.fill_mode,
            "fill_value": self.fill_value,
            "interpolation": self.interpolation,
            "seed": self.seed,
        }
        base_config = super().get_config()
        return {**base_config, **config}

0 commit comments

Comments
 (0)