|
| 1 | +# Copyright 2023 The KerasCV Authors |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# https://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +import tensorflow as tf |
| 16 | +from keras import backend |
| 17 | + |
| 18 | +from keras_cv.layers.preprocessing.base_image_augmentation_layer import ( |
| 19 | + BaseImageAugmentationLayer, |
| 20 | +) |
| 21 | +from keras_cv.utils import preprocessing |
| 22 | + |
| 23 | + |
def check_fill_mode_and_interpolation(fill_mode, interpolation):
    """Validate `fill_mode` and `interpolation` against the supported values.

    Args:
        fill_mode: one of `"reflect"`, `"wrap"`, `"constant"`, `"nearest"`.
        interpolation: one of `"nearest"`, `"bilinear"`.

    Raises:
        NotImplementedError: if either argument is outside its supported set.
    """
    supported_fill_modes = ("reflect", "wrap", "constant", "nearest")
    if fill_mode not in supported_fill_modes:
        raise NotImplementedError(
            f"Unknown `fill_mode` {fill_mode}. Only `reflect`, `wrap`, "
            "`constant` and `nearest` are supported."
        )
    supported_interpolations = ("nearest", "bilinear")
    if interpolation not in supported_interpolations:
        raise NotImplementedError(
            f"Unknown `interpolation` {interpolation}. Only `nearest` and "
            "`bilinear` are supported."
        )
| 35 | + |
| 36 | + |
def get_translation_matrix(translations, name=None):
    """Build projective transforms for a batch of `[dx, dy]` translations.

    Args:
        translations: A matrix of 2-element lists representing `[dx, dy]`
            to translate for each image (for a batch of images).
        name: The name of the op.

    Returns:
        A tensor of shape `(num_images, 8)` projective transforms which can
        be given to `transform`.
    """
    with backend.name_scope(name or "translation_matrix"):
        batch = tf.shape(translations)[0]
        # Each row flattens the 3x3 matrix
        #   [[1 0 -dx]
        #    [0 1 -dy]
        #    [0 0  1]]
        # into 8 float32 entries; the trailing 1 is implicit in the
        # projective-transform convention.
        ones_col = tf.ones((batch, 1), tf.float32)
        zeros_col = tf.zeros((batch, 1), tf.float32)
        neg_dx = -translations[:, 0, None]
        neg_dy = -translations[:, 1, None]
        return tf.concat(
            [
                ones_col,
                zeros_col,
                neg_dx,
                zeros_col,
                ones_col,
                neg_dy,
                tf.zeros((batch, 2), tf.float32),
            ],
            axis=1,
        )
| 69 | + |
| 70 | + |
# Axis indices for `"channels_last"` image tensors: height is the
# third-from-last dimension and width the second-from-last, which holds for
# both 3D (unbatched) and 4D (batched) inputs.
H_AXIS = -3
W_AXIS = -2
| 73 | + |
| 74 | + |
class RandomTranslation(BaseImageAugmentationLayer):
    """A preprocessing layer which randomly translates images during training.

    This layer will apply random translations to each image during training,
    filling empty space according to `fill_mode`.

    Input pixel values can be of any range (e.g. `[0., 1.)` or `[0, 255]`) and
    of integer or floating point dtype. By default, the layer will output
    floats.

    Args:
        height_factor: a float represented as fraction of value, or a tuple of
            size 2 representing lower and upper bound for shifting vertically.
            A negative value means shifting image up, while a positive value
            means shifting image down. When represented as a single positive
            float, this value is used for both the upper and lower bound. For
            instance, `height_factor=(-0.2, 0.3)` results in an output shifted
            by a random amount in the range `[-20%, +30%]`.
            `height_factor=0.2` results in an output height shifted by a
            random amount in the range `[-20%, +20%]`.
        width_factor: a float represented as fraction of value, or a tuple of
            size 2 representing lower and upper bound for shifting
            horizontally. A negative value means shifting image left, while a
            positive value means shifting image right. When represented as a
            single positive float, this value is used for both the upper and
            lower bound. For instance, `width_factor=(-0.2, 0.3)` results in
            an output shifted left by 20%, and shifted right by 30%.
            `width_factor=0.2` results in an output height shifted left or
            right by 20%.
        fill_mode: Points outside the boundaries of the input are filled
            according to the given mode
            (one of `{"constant", "reflect", "wrap", "nearest"}`).
            - *reflect*: `(d c b a | a b c d | d c b a)` The input is extended
              by reflecting about the edge of the last pixel.
            - *constant*: `(k k k k | a b c d | k k k k)` The input is
              extended by filling all values beyond the edge with the same
              constant value k = 0.
            - *wrap*: `(a b c d | a b c d | a b c d)` The input is extended by
              wrapping around to the opposite edge.
            - *nearest*: `(a a a a | a b c d | d d d d)` The input is extended
              by the nearest pixel.
        interpolation: Interpolation mode. Supported values: `"nearest"`,
            `"bilinear"`.
        seed: Integer. Used to create a random seed.
        fill_value: a float represents the value to be filled outside the
            boundaries when `fill_mode="constant"`.

    Input shape:
        3D (unbatched) or 4D (batched) tensor with shape:
        `(..., height, width, channels)`, in `"channels_last"` format.

    Output shape:
        3D (unbatched) or 4D (batched) tensor with shape:
        `(..., height, width, channels)`, in `"channels_last"` format.
    """

    def __init__(
        self,
        height_factor,
        width_factor,
        fill_mode="reflect",
        interpolation="bilinear",
        seed=None,
        fill_value=0.0,
        **kwargs,
    ):
        super().__init__(seed=seed, force_generator=True, **kwargs)
        # The height/width factor handling is identical; the shared
        # normalization and validation lives in `_parse_factor`.
        self.height_factor = height_factor
        self.height_lower, self.height_upper = self._parse_factor(
            height_factor, "height_factor"
        )
        self.width_factor = width_factor
        self.width_lower, self.width_upper = self._parse_factor(
            width_factor, "width_factor"
        )

        check_fill_mode_and_interpolation(fill_mode, interpolation)

        self.fill_mode = fill_mode
        self.fill_value = fill_value
        self.interpolation = interpolation
        self.seed = seed

    @staticmethod
    def _parse_factor(factor, name):
        """Normalize a scalar-or-pair factor into a validated (lower, upper).

        A scalar `f` becomes `(-f, f)`; a tuple/list is taken as
        `(factor[0], factor[1])`.

        Args:
            factor: float or 2-element tuple/list of floats in `[-1, 1]`.
            name: the constructor argument name, used in error messages.

        Returns:
            `(lower, upper)` bounds.

        Raises:
            ValueError: if `upper < lower` or either bound exceeds 1 in
                magnitude.
        """
        if isinstance(factor, (tuple, list)):
            lower = factor[0]
            upper = factor[1]
        else:
            lower = -factor
            upper = factor
        if upper < lower:
            raise ValueError(
                f"`{name}` cannot have upper bound less than "
                f"lower bound, got {factor}"
            )
        if abs(lower) > 1.0 or abs(upper) > 1.0:
            raise ValueError(
                f"`{name}` must have values between [-1, 1], "
                f"got {factor}"
            )
        return lower, upper

    def augment_image(self, image, transformation, **kwargs):
        """Translated inputs with random ops."""
        # The transform op only accepts rank 4 inputs, so if we have an
        # unbatched image, we need to temporarily expand dims to a batch.
        original_shape = image.shape
        inputs = tf.expand_dims(image, 0)

        inputs_shape = tf.shape(inputs)
        img_hd = tf.cast(inputs_shape[H_AXIS], tf.float32)
        img_wd = tf.cast(inputs_shape[W_AXIS], tf.float32)
        # Transformation factors are fractions of the image size; convert
        # them to pixel offsets before building the projective transform.
        height_translation = transformation["height_translation"]
        width_translation = transformation["width_translation"]
        height_translation = height_translation * img_hd
        width_translation = width_translation * img_wd
        translations = tf.cast(
            tf.concat([width_translation, height_translation], axis=1),
            dtype=tf.float32,
        )
        output = preprocessing.transform(
            inputs,
            get_translation_matrix(translations),
            interpolation=self.interpolation,
            fill_mode=self.fill_mode,
            fill_value=self.fill_value,
        )

        # Drop the temporary batch dimension and restore static shape info
        # lost by the transform op.
        output = tf.squeeze(output, 0)
        output.set_shape(original_shape)
        return output

    def get_random_transformation(self, image=None, **kwargs):
        """Sample per-image translation fractions for height and width."""
        # One sample per image; batching is handled by `_batch_augment`
        # mapping this over the batch.
        batch_size = 1
        height_translation = self._random_generator.random_uniform(
            shape=[batch_size, 1],
            minval=self.height_lower,
            maxval=self.height_upper,
            dtype=tf.float32,
        )
        width_translation = self._random_generator.random_uniform(
            shape=[batch_size, 1],
            minval=self.width_lower,
            maxval=self.width_upper,
            dtype=tf.float32,
        )
        return {
            "height_translation": height_translation,
            "width_translation": width_translation,
        }

    def _batch_augment(self, inputs):
        # Change to vectorized_map for better performance, as well as work
        # around issue for different tensorspec between inputs and outputs.
        return tf.vectorized_map(self._augment, inputs)

    def augment_label(self, label, transformation, **kwargs):
        """Labels are unaffected by translation; return them unchanged."""
        return label

    def compute_output_shape(self, input_shape):
        # Translation never changes the spatial dimensions.
        return input_shape

    def get_config(self):
        """Return the constructor arguments for layer serialization."""
        config = {
            "height_factor": self.height_factor,
            "width_factor": self.width_factor,
            "fill_mode": self.fill_mode,
            "fill_value": self.fill_value,
            "interpolation": self.interpolation,
            "seed": self.seed,
        }
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))
0 commit comments