Skip to content

Commit 3388639

Browse files
authored
Implement RepeatedAugmentation as a KerasCV API (#1293)
* Implement RepeatedAugmentation as a KerasCV API more reading and fixes #372 * add test case * fix formatting * fix formatting * fix formatting * fix serialization test * add repeated augmentation usage docstring * Update component for repeated augment * Repeated augmentations fix * Test MixUp explicitly * update docstring * update docstring * Reformat * keras_cv/layers/preprocessing/repeated_augmentation.py
1 parent 3c572ec commit 3388639

File tree

5 files changed

+175
-0
lines changed

5 files changed

+175
-0
lines changed

keras_cv/layers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@
7171
from keras_cv.layers.preprocessing.random_shear import RandomShear
7272
from keras_cv.layers.preprocessing.random_zoom import RandomZoom
7373
from keras_cv.layers.preprocessing.randomly_zoomed_crop import RandomlyZoomedCrop
74+
from keras_cv.layers.preprocessing.repeated_augmentation import RepeatedAugmentation
7475
from keras_cv.layers.preprocessing.rescaling import Rescaling
7576
from keras_cv.layers.preprocessing.resizing import Resizing
7677
from keras_cv.layers.preprocessing.solarization import Solarization

keras_cv/layers/preprocessing/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
from keras_cv.layers.preprocessing.random_shear import RandomShear
6262
from keras_cv.layers.preprocessing.random_zoom import RandomZoom
6363
from keras_cv.layers.preprocessing.randomly_zoomed_crop import RandomlyZoomedCrop
64+
from keras_cv.layers.preprocessing.repeated_augmentation import RepeatedAugmentation
6465
from keras_cv.layers.preprocessing.rescaling import Rescaling
6566
from keras_cv.layers.preprocessing.resizing import Resizing
6667
from keras_cv.layers.preprocessing.solarization import Solarization
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
# Copyright 2023 The KerasCV Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import tensorflow as tf
15+
16+
from keras_cv.layers.preprocessing.base_image_augmentation_layer import (
17+
BaseImageAugmentationLayer,
18+
)
19+
20+
21+
@tf.keras.utils.register_keras_serializable(package="keras_cv")
22+
class RepeatedAugmentation(BaseImageAugmentationLayer):
23+
"""RepeatedAugmentation augments each image in a batch multiple times.
24+
25+
This technique exists to emulate the behavior of stochastic gradient descent within
26+
the context of mini-batch gradient descent. When training large vision models,
27+
choosing a large batch size can introduce too much noise into aggregated gradients
28+
causing the overall batch's gradients to be less effective than gradients produced
29+
using smaller gradients. RepeatedAugmentation handles this by re-using the same
30+
image multiple times within a batch creating correlated samples.
31+
32+
This layer increases your batch size by a factor of `len(augmenters)`.
33+
34+
Args:
35+
augmenters: the augmenters to use to augment the image
36+
shuffle: whether or not to shuffle the result. Essential when using an
37+
asynchronous distribution strategy such as ParameterServerStrategy.
38+
39+
Usage:
40+
41+
List of identical augmenters:
42+
```python
43+
repeated_augment = cv_layers.RepeatedAugmentation(
44+
augmenters=[cv_layers.RandAugment(value_range=(0, 255))] * 8
45+
)
46+
inputs = {
47+
"images": tf.ones((8, 512, 512, 3)),
48+
"labels": tf.ones((8,)),
49+
}
50+
outputs = repeated_augment(inputs)
51+
# outputs now has a batch size of 64 because there are 8 augmenters
52+
```
53+
54+
List of distinct augmenters:
55+
```python
56+
repeated_augment = cv_layers.RepeatedAugmentation(
57+
augmenters=[
58+
cv_layers.RandAugment(value_range=(0, 255)),
59+
cv_layers.RandomFlip(),
60+
]
61+
)
62+
inputs = {
63+
"images": tf.ones((8, 512, 512, 3)),
64+
"labels": tf.ones((8,)),
65+
}
66+
outputs = repeated_augment(inputs)
67+
```
68+
69+
References:
70+
- [DEIT implementaton](https://github.com/facebookresearch/deit/blob/ee8893c8063f6937fec7096e47ba324c206e22b9/samplers.py#L8)
71+
- [Original publication](https://openaccess.thecvf.com/content_CVPR_2020/papers/Hoffer_Augment_Your_Batch_Improving_Generalization_Through_Instance_Repetition_CVPR_2020_paper.pdf)
72+
73+
"""
74+
75+
def __init__(self, augmenters, shuffle=True, **kwargs):
76+
super().__init__(**kwargs)
77+
self.augmenters = augmenters
78+
self.shuffle = shuffle
79+
80+
def _batch_augment(self, inputs):
81+
if "bounding_boxes" in inputs:
82+
raise ValueError(
83+
"RepeatedAugmentation() does not yet support bounding box labels."
84+
)
85+
86+
augmenter_outputs = [augmenter(inputs) for augmenter in self.augmenters]
87+
88+
outputs = {}
89+
for k in inputs.keys():
90+
outputs[k] = tf.concat([output[k] for output in augmenter_outputs], axis=0)
91+
92+
if not self.shuffle:
93+
return outputs
94+
return self.shuffle_outputs(outputs)
95+
96+
def shuffle_outputs(self, result):
97+
indices = tf.range(start=0, limit=tf.shape(result["images"])[0], dtype=tf.int32)
98+
indices = tf.random.shuffle(indices)
99+
for key in result:
100+
result[key] = tf.gather(result[key], indices)
101+
return result
102+
103+
def _augment(self, inputs):
104+
raise ValueError(
105+
"RepeatedAugmentation() only works in batched mode. If "
106+
"you would like to create batches from a single image, use "
107+
"`x = tf.expand_dims(x, axis=0)` on your input images and labels."
108+
)
109+
110+
def get_config(self):
111+
config = super().get_config()
112+
config.update({"augmenters": self.augmenters, "shuffle": self.shuffle})
113+
return config
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# Copyright 2023 The KerasCV Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import tensorflow as tf
15+
16+
import keras_cv.layers as cv_layers
17+
18+
19+
class RepeatedAugmentationTest(tf.test.TestCase):
20+
def test_output_shapes(self):
21+
repeated_augment = cv_layers.RepeatedAugmentation(
22+
augmenters=[
23+
cv_layers.RandAugment(value_range=(0, 255)),
24+
cv_layers.RandomFlip(),
25+
]
26+
)
27+
inputs = {
28+
"images": tf.ones((8, 512, 512, 3)),
29+
"labels": tf.ones((8,)),
30+
}
31+
outputs = repeated_augment(inputs)
32+
33+
self.assertEqual(outputs["images"].shape, (16, 512, 512, 3))
34+
self.assertEqual(outputs["labels"].shape, (16,))
35+
36+
def test_with_mix_up(self):
37+
repeated_augment = cv_layers.RepeatedAugmentation(
38+
augmenters=[
39+
cv_layers.RandAugment(value_range=(0, 255)),
40+
cv_layers.MixUp(),
41+
]
42+
)
43+
inputs = {
44+
"images": tf.ones((8, 512, 512, 3)),
45+
"labels": tf.ones((8, 10)),
46+
}
47+
outputs = repeated_augment(inputs)
48+
49+
self.assertEqual(outputs["images"].shape, (16, 512, 512, 3))
50+
self.assertEqual(outputs["labels"].shape, (16, 10))

keras_cv/layers/serialization_test.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,16 @@ class SerializationTest(tf.test.TestCase, parameterized.TestCase):
9696
("GridMask", cv_layers.GridMask, {"seed": 1}),
9797
("MixUp", cv_layers.MixUp, {"seed": 1}),
9898
("Mosaic", cv_layers.Mosaic, {"seed": 1}),
99+
(
100+
"RepeatedAugmentation",
101+
cv_layers.RepeatedAugmentation,
102+
{
103+
"augmenters": [
104+
cv_layers.RandAugment(value_range=(0, 1)),
105+
cv_layers.RandomFlip(),
106+
]
107+
},
108+
),
99109
(
100110
"RandomChannelShift",
101111
cv_layers.RandomChannelShift,

0 commit comments

Comments
 (0)