Skip to content

Commit 6069b15

Browse files
Authored and committed by Ian Stenbit
Removes label smoothing from CutMix and MixUp (keras-team#70)
* Removes label smoothing from CutMix and MixUp
* Run format
1 parent ea41f76 commit 6069b15

File tree

4 files changed

+11
-65
lines changed

4 files changed

+11
-65
lines changed

keras_cv/layers/preprocessing/cut_mix.py

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,6 @@ class CutMix(layers.Layer):
2828
distribution. This controls the shape of the distribution from which the
2929
smoothing values are sampled. Defaults 1.0, which is a recommended value
3030
when training an imagenet1k classification model.
31-
label_smoothing: Float in [0, 1]. When > 0, label values are smoothed,
32-
meaning the confidence on label values are relaxed. e.g.
33-
label_smoothing=0.2 means that we will use a value of 0.1 for label 0 and
34-
0.9 for label 1. Defaults 0.0.
3531
References:
3632
[CutMix paper]( https://arxiv.org/abs/1905.04899).
3733
@@ -43,11 +39,10 @@ class CutMix(layers.Layer):
4339
```
4440
"""
4541

46-
def __init__(self, rate, label_smoothing=0.0, alpha=1.0, seed=None, **kwargs):
42+
def __init__(self, rate, alpha=1.0, seed=None, **kwargs):
4743
super().__init__(**kwargs)
4844
self.alpha = alpha
4945
self.rate = rate
50-
self.label_smoothing = label_smoothing
5146
self.seed = seed
5247

5348
@staticmethod
@@ -85,7 +80,7 @@ def call(self, images, labels, training=True):
8580
augment_cond = tf.logical_and(rate_cond, training)
8681
# pylint: disable=g-long-lambda
8782
cutmix_augment = lambda: self._update_labels(*self._cutmix(images, labels))
88-
no_augment = lambda: (images, self._smooth_labels(labels))
83+
no_augment = lambda: (images, labels)
8984
return tf.cond(augment_cond, cutmix_augment, no_augment)
9085

9186
def _cutmix(self, images, labels):
@@ -132,15 +127,8 @@ def _cutmix(self, images, labels):
132127
return images, labels, lambda_sample, permutation_order
133128

134129
def _update_labels(self, images, labels, lambda_sample, permutation_order):
135-
labels_smoothed = self._smooth_labels(labels)
136130
cutout_labels = tf.gather(labels, permutation_order)
137131

138132
lambda_sample = tf.reshape(lambda_sample, [-1, 1])
139-
labels = lambda_sample * labels_smoothed + (1.0 - lambda_sample) * cutout_labels
133+
labels = lambda_sample * labels + (1.0 - lambda_sample) * cutout_labels
140134
return images, labels
141-
142-
def _smooth_labels(self, labels):
143-
label_smoothing = self.label_smoothing or 0.0
144-
off_value = label_smoothing / tf.cast(tf.shape(labels)[1], tf.float32)
145-
on_value = 1.0 - label_smoothing + off_value
146-
return labels * on_value + (1 - labels) * off_value

keras_cv/layers/preprocessing/cut_mix_test.py

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -30,21 +30,7 @@ def test_return_shapes(self):
3030
xs, ys = layer(xs, ys)
3131

3232
self.assertEqual(xs.shape, [2, 512, 512, 3])
33-
# one hot smoothed labels
3433
self.assertEqual(ys.shape, [2, 10])
35-
self.assertEqual(len(ys != 0.0), 2)
36-
37-
def test_label_smoothing(self):
38-
xs = tf.ones((2, 512, 512, 3))
39-
# randomly sample labels
40-
ys = tf.random.categorical(tf.math.log([[0.5, 0.5]]), 2)
41-
ys = tf.squeeze(ys)
42-
ys = tf.one_hot(ys, NUM_CLASSES)
43-
44-
layer = CutMix(1.0, label_smoothing=0.2, seed=1)
45-
xs, ys = layer(xs, ys)
46-
self.assertNotAllClose(ys, 0.0)
47-
self.assertAllClose(tf.math.reduce_sum(ys, axis=-1), (1.0, 1.0))
4834

4935
def test_cut_mix_call_results(self):
5036
xs = tf.cast(
@@ -56,7 +42,7 @@ def test_cut_mix_call_results(self):
5642
)
5743
ys = tf.one_hot(tf.constant([0, 1]), 2)
5844

59-
layer = CutMix(1.0, label_smoothing=0.0, seed=1)
45+
layer = CutMix(1.0, seed=1)
6046
xs, ys = layer(xs, ys)
6147

6248
# At least some pixels should be replaced in the CutMix operation
@@ -78,7 +64,7 @@ def test_cut_mix_call_results_one_channel(self):
7864
)
7965
ys = tf.one_hot(tf.constant([0, 1]), 2)
8066

81-
layer = CutMix(1.0, label_smoothing=0.0, seed=1)
67+
layer = CutMix(1.0, seed=1)
8268
xs, ys = layer(xs, ys)
8369

8470
# At least some pixels should be replaced in the CutMix operation
@@ -97,7 +83,7 @@ def test_in_tf_function(self):
9783
)
9884
ys = tf.one_hot(tf.constant([0, 1]), 2)
9985

100-
layer = CutMix(1.0, label_smoothing=0.0, seed=1)
86+
layer = CutMix(1.0, seed=1)
10187

10288
@tf.function
10389
def augment(x, y):

keras_cv/layers/preprocessing/mix_up.py

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,6 @@ class MixUp(layers.Layer):
2626
distribution. This controls the shape of the distribution from which the
2727
smoothing values are sampled. Defaults 0.2, which is a recommended value
2828
when training an imagenet1k classification model.
29-
label_smoothing: Float in [0, 1]. When > 0, label values are smoothed,
30-
meaning the confidence on label values are relaxed. e.g.
31-
label_smoothing=0.2 means that we will use a value of 0.1 for label 0 and
32-
0.9 for label 1. Defaults 0.0.
3329
References:
3430
[MixUp paper](https://arxiv.org/abs/1710.09412).
3531
@@ -41,11 +37,10 @@ class MixUp(layers.Layer):
4137
```
4238
"""
4339

44-
def __init__(self, rate, label_smoothing=0.0, alpha=0.2, seed=None, **kwargs):
40+
def __init__(self, rate, alpha=0.2, seed=None, **kwargs):
4541
super().__init__(**kwargs)
4642
self.alpha = alpha
4743
self.rate = rate
48-
self.label_smoothing = label_smoothing
4944
self.seed = seed
5045

5146
@staticmethod
@@ -83,7 +78,7 @@ def call(self, images, labels, training=True):
8378
augment_cond = tf.logical_and(rate_cond, training)
8479
# pylint: disable=g-long-lambda
8580
mixup_augment = lambda: self._update_labels(*self._mixup(images, labels))
86-
no_augment = lambda: (images, self._smooth_labels(labels))
81+
no_augment = lambda: (images, labels)
8782
return tf.cond(augment_cond, mixup_augment, no_augment)
8883

8984
def _mixup(self, images, labels):
@@ -99,18 +94,9 @@ def _mixup(self, images, labels):
9994
return images, labels, tf.squeeze(lambda_sample), permutation_order
10095

10196
def _update_labels(self, images, labels, lambda_sample, permutation_order):
102-
labels_smoothed = self._smooth_labels(labels)
10397
labels_for_mixup = tf.gather(labels, permutation_order)
10498

10599
lambda_sample = tf.reshape(lambda_sample, [-1, 1])
106-
labels = (
107-
lambda_sample * labels_smoothed + (1.0 - lambda_sample) * labels_for_mixup
108-
)
100+
labels = lambda_sample * labels + (1.0 - lambda_sample) * labels_for_mixup
109101

110102
return images, labels
111-
112-
def _smooth_labels(self, labels):
113-
label_smoothing = self.label_smoothing or 0.0
114-
off_value = label_smoothing / tf.cast(tf.shape(labels)[1], tf.float32)
115-
on_value = 1.0 - label_smoothing + off_value
116-
return on_value * labels + (1 - labels) * off_value

keras_cv/layers/preprocessing/mix_up_test.py

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -30,21 +30,7 @@ def test_return_shapes(self):
3030
xs, ys = layer(xs, ys)
3131

3232
self.assertEqual(xs.shape, [2, 512, 512, 3])
33-
# one hot smoothed labels
3433
self.assertEqual(ys.shape, [2, 10])
35-
self.assertEqual(len(ys != 0.0), 2)
36-
37-
def test_label_smoothing(self):
38-
xs = tf.ones((2, 512, 512, 3))
39-
# randomly sample labels
40-
ys = tf.random.categorical(tf.math.log([[0.5, 0.5]]), 2)
41-
ys = tf.squeeze(ys)
42-
ys = tf.one_hot(ys, NUM_CLASSES)
43-
44-
layer = MixUp(1.0, label_smoothing=0.2)
45-
xs, ys = layer(xs, ys)
46-
self.assertNotAllClose(ys, 0.0)
47-
self.assertAllClose(tf.math.reduce_sum(ys, axis=-1), (1.0, 1.0))
4834

4935
def test_mix_up_call_results(self):
5036
xs = tf.cast(
@@ -56,7 +42,7 @@ def test_mix_up_call_results(self):
5642
)
5743
ys = tf.one_hot(tf.constant([0, 1]), 2)
5844

59-
layer = MixUp(1.0, label_smoothing=0.0)
45+
layer = MixUp(1.0)
6046
xs, ys = layer(xs, ys)
6147

6248
# None of the individual values should still be close to 1 or 0
@@ -77,7 +63,7 @@ def test_in_tf_function(self):
7763
)
7864
ys = tf.one_hot(tf.constant([0, 1]), 2)
7965

80-
layer = MixUp(1.0, label_smoothing=0.0)
66+
layer = MixUp(1.0)
8167

8268
@tf.function
8369
def augment(x, y):

0 commit comments

Comments
 (0)