Skip to content

Commit d508f48

Browse files
chjortLukeWood
andauthored
Implementation of Cutout layer and RandomErasing layer (#47)
* initiate CutOut layer * Implement CutOut layer * docstring and rename * initiate RandomErase layer * Implement RandomErase layer * refactor CutOut and RandomErase to inherit from same base class * docstring * remove unnecessary loop * rename to noun instead of verb * rename to noun instead of verb * rename to noun instead of verb * add get_config methods * rename to noun instead of verb * rename to noun instead of verb * remove unnecessary check * minor changes requested * refactor fill_rectangle into a separate utils file * import module instead of function * refactor CutOut layer * refactor RandomErasing layer * add CutOut, RandomErasing to __init__.py * add training argument to layer call * rearrange imports * remove labels argument from CutOut and RandomErasing * add training argument to MixUp and CutMix * update argument names for RandomErasing, and more detailed docstring. * docstring fix * docstring fix * rearrange imports * linting * add license * fix docstring * fix docstring * fix docstring * refactor into single layer * remove old layer * refactor to make layer more extendable * linting and additional tests * black * edit comment * ceil height and width * Update random_cutout to default to a 0 for the lower bound * Remove unneeded if statement Co-authored-by: Luke Wood <[email protected]>
1 parent c7bf271 commit d508f48

File tree

7 files changed

+539
-50
lines changed

7 files changed

+539
-50
lines changed
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
"""random_cutout_demo.py shows how to use the RandomCutout preprocessing layer.
2+
3+
Operates on the oxford_flowers102 dataset. In this script the flowers
4+
are loaded, then are passed through the preprocessing layers.
5+
Finally, they are shown using matplotlib.
6+
"""
7+
import matplotlib.pyplot as plt
8+
import tensorflow as tf
9+
import tensorflow_datasets as tfds
10+
11+
from keras_cv.layers import preprocessing
12+
13+
IMG_SIZE = (224, 224)
14+
BATCH_SIZE = 64
15+
16+
17+
def resize(image, label, num_classes=10):
18+
image = tf.image.resize(image, IMG_SIZE)
19+
label = tf.one_hot(label, num_classes)
20+
return image, label
21+
22+
23+
def main():
24+
data, ds_info = tfds.load(
25+
"oxford_flowers102", with_info=True, as_supervised=True
26+
)
27+
train_ds = data["train"]
28+
29+
num_classes = ds_info.features["label"].num_classes
30+
31+
train_ds = (
32+
train_ds.map(lambda x, y: resize(x, y, num_classes=num_classes))
33+
.shuffle(10 * BATCH_SIZE)
34+
.batch(BATCH_SIZE)
35+
)
36+
random_cutout = preprocessing.RandomCutout(
37+
height_factor=(0.3, 0.9),
38+
width_factor=64,
39+
fill_mode="gaussian_noise",
40+
rate=1.0,
41+
)
42+
train_ds = train_ds.map(
43+
lambda x, y: (random_cutout(x), y), num_parallel_calls=tf.data.AUTOTUNE
44+
)
45+
46+
for images, labels in train_ds.take(1):
47+
plt.figure(figsize=(8, 8))
48+
for i in range(9):
49+
plt.subplot(3, 3, i + 1)
50+
plt.imshow(images[i].numpy().astype("uint8"))
51+
plt.axis("off")
52+
plt.show()
53+
54+
55+
if __name__ == "__main__":
56+
main()

keras_cv/layers/preprocessing/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,4 @@
1414

1515
from keras_cv.layers.preprocessing.cut_mix import CutMix
1616
from keras_cv.layers.preprocessing.mix_up import MixUp
17+
from keras_cv.layers.preprocessing.random_cutout import RandomCutout

keras_cv/layers/preprocessing/cut_mix.py

Lines changed: 9 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414
import tensorflow as tf
1515
import tensorflow.keras.layers as layers
1616
from absl import logging
17+
from tensorflow.keras import backend
18+
19+
from keras_cv.utils import fill_utils
1720

1821

1922
class CutMix(layers.Layer):
@@ -53,7 +56,7 @@ def _sample_from_beta(alpha, beta, shape):
5356
sample_beta = tf.random.gamma(shape, 1.0, beta=beta)
5457
return sample_alpha / (sample_alpha + sample_beta)
5558

56-
def call(self, images, labels):
59+
def call(self, images, labels, training=True):
5760
"""call method for the CutMix layer.
5861
5962
Args:
@@ -66,6 +69,8 @@ def call(self, images, labels):
6669
labels: updated labels with both label smoothing and the cutmix updates
6770
applied.
6871
"""
72+
if training is None:
73+
training = backend.learning_phase()
6974

7075
if tf.shape(images)[0] == 1:
7176
logging.warning(
@@ -74,9 +79,10 @@ def call(self, images, labels):
7479
"expected. Please call the layer with 2 or more samples."
7580
)
7681

77-
augment_cond = tf.less(
82+
rate_cond = tf.less(
7883
tf.random.uniform(shape=[], minval=0.0, maxval=1.0), self.rate
7984
)
85+
augment_cond = tf.logical_and(rate_cond, training)
8086
# pylint: disable=g-long-lambda
8187
cutmix_augment = lambda: self._update_labels(*self._cutmix(images, labels))
8288
no_augment = lambda: (images, self._smooth_labels(labels))
@@ -119,7 +125,7 @@ def _cutmix(self, images, labels):
119125
lambda_sample = tf.cast(lambda_sample, dtype=tf.float32)
120126

121127
images = tf.map_fn(
122-
lambda x: _fill_rectangle(*x),
128+
lambda x: fill_utils.fill_rectangle(*x),
123129
(
124130
images,
125131
random_center_width,
@@ -148,48 +154,3 @@ def _smooth_labels(self, labels):
148154
off_value = label_smoothing / tf.cast(tf.shape(labels)[1], tf.float32)
149155
on_value = 1.0 - label_smoothing + off_value
150156
return labels * on_value + (1 - labels) * off_value
151-
152-
153-
def _fill_rectangle(
154-
image, center_width, center_height, half_width, half_height, replace=None
155-
):
156-
"""Fill a rectangle in a given image using the value provided in replace.
157-
158-
Args:
159-
image: the starting image to fill the rectangle on.
160-
center_width: the X center of the rectangle to fill
161-
center_height: the Y center of the rectangle to fill
162-
half_width: 1/2 the width of the resulting rectangle
163-
half_height: 1/2 the height of the resulting rectangle
164-
replace: The value to fill the rectangle with. Accepts a Tensor,
165-
Constant, or None.
166-
Returns:
167-
image: the modified image with the chosen rectangle filled.
168-
"""
169-
image_shape = tf.shape(image)
170-
image_height = image_shape[0]
171-
image_width = image_shape[1]
172-
173-
lower_pad = tf.maximum(0, center_height - half_height)
174-
upper_pad = tf.maximum(0, image_height - center_height - half_height)
175-
left_pad = tf.maximum(0, center_width - half_width)
176-
right_pad = tf.maximum(0, image_width - center_width - half_width)
177-
178-
cutout_shape = [
179-
image_height - (lower_pad + upper_pad),
180-
image_width - (left_pad + right_pad),
181-
]
182-
padding_dims = [[lower_pad, upper_pad], [left_pad, right_pad]]
183-
mask = tf.pad(
184-
tf.zeros(cutout_shape, dtype=image.dtype), padding_dims, constant_values=1
185-
)
186-
mask = tf.expand_dims(mask, -1)
187-
188-
if replace is None:
189-
fill = tf.random.normal(image_shape, dtype=image.dtype)
190-
elif isinstance(replace, tf.Tensor):
191-
fill = replace
192-
else:
193-
fill = tf.ones_like(image, dtype=image.dtype) * replace
194-
image = tf.where(tf.equal(mask, 0), fill, image)
195-
return image

keras_cv/layers/preprocessing/mix_up.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import tensorflow as tf
1515
import tensorflow.keras.layers as layers
1616
from absl import logging
17+
from tensorflow.keras import backend
1718

1819

1920
class MixUp(layers.Layer):
@@ -53,7 +54,7 @@ def _sample_from_beta(alpha, beta, shape):
5354
sample_beta = tf.random.gamma(shape, 1.0, beta=beta)
5455
return sample_alpha / (sample_alpha + sample_beta)
5556

56-
def call(self, images, labels):
57+
def call(self, images, labels, training=True):
5758
"""call method for the MixUp layer.
5859
5960
Args:
@@ -66,6 +67,8 @@ def call(self, images, labels):
6667
labels: updated labels with both label smoothing and the cutmix updates
6768
applied.
6869
"""
70+
if training is None:
71+
training = backend.learning_phase()
6972

7073
if tf.shape(images)[0] == 1:
7174
logging.warning(
@@ -74,9 +77,10 @@ def call(self, images, labels):
7477
"expected. Please call the layer with 2 or more samples."
7578
)
7679

77-
augment_cond = tf.less(
80+
rate_cond = tf.less(
7881
tf.random.uniform(shape=[], minval=0.0, maxval=1.0), self.rate
7982
)
83+
augment_cond = tf.logical_and(rate_cond, training)
8084
# pylint: disable=g-long-lambda
8185
mixup_augment = lambda: self._update_labels(*self._mixup(images, labels))
8286
no_augment = lambda: (images, self._smooth_labels(labels))

0 commit comments

Comments
 (0)