Skip to content

Commit 4aa6a67

Browse files
authored
Update example and logic for mix_up (#20642)
* Update example and logic for mix_up * remove tf from example
1 parent 7a16b8e commit 4aa6a67

File tree

1 file changed

+15
-19
lines changed
  • keras/src/layers/preprocessing/image_preprocessing

1 file changed

+15
-19
lines changed

keras/src/layers/preprocessing/image_preprocessing/mix_up.py

Lines changed: 15 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,10 @@ class MixUp(BaseImagePreprocessingLayer):
2323
Example:
2424
```python
2525
(images, labels), _ = keras.datasets.cifar10.load_data()
26-
images, labels = images[:10], labels[:10]
27-
# Labels must be floating-point and one-hot encoded
28-
labels = tf.cast(tf.one_hot(labels, 10), tf.float32)
29-
mixup = keras.layers.MixUp(alpha=0.2)
30-
augmented_images, updated_labels = mixup(
31-
{'images': images, 'labels': labels}
32-
)
33-
# output == {'images': updated_images, 'labels': updated_labels}
26+
images, labels = images[:8], labels[:8]
27+
labels = keras.ops.cast(keras.ops.one_hot(labels.flatten(), 10), "float32")
28+
mix_up = keras.layers.MixUp(alpha=0.2)
29+
output = mix_up({"images": images, "labels": labels})
3430
```
3531
"""
3632

@@ -62,7 +58,7 @@ def get_random_transformation(self, data, training=True, seed=None):
6258
)
6359

6460
mix_weight = self.backend.random.beta(
65-
(1,), self.alpha, self.alpha, seed=seed
61+
(batch_size,), self.alpha, self.alpha, seed=seed
6662
)
6763
return {
6864
"mix_weight": mix_weight,
@@ -79,26 +75,26 @@ def transform_images(self, images, transformation=None, training=True):
7975
dtype=self.compute_dtype,
8076
)
8177

82-
mixup_images = self.backend.cast(
78+
mix_up_images = self.backend.cast(
8379
self.backend.numpy.take(images, permutation_order, axis=0),
8480
dtype=self.compute_dtype,
8581
)
8682

87-
images = mix_weight * images + (1.0 - mix_weight) * mixup_images
83+
images = mix_weight * images + (1.0 - mix_weight) * mix_up_images
8884

8985
return images
9086

9187
def transform_labels(self, labels, transformation, training=True):
9288
mix_weight = transformation["mix_weight"]
9389
permutation_order = transformation["permutation_order"]
9490

95-
labels_for_mixup = self.backend.numpy.take(
91+
labels_for_mix_up = self.backend.numpy.take(
9692
labels, permutation_order, axis=0
9793
)
9894

9995
mix_weight = self.backend.numpy.reshape(mix_weight, [-1, 1])
10096

101-
labels = mix_weight * labels + (1.0 - mix_weight) * labels_for_mixup
97+
labels = mix_weight * labels + (1.0 - mix_weight) * labels_for_mix_up
10298

10399
return labels
104100

@@ -110,11 +106,11 @@ def transform_bounding_boxes(
110106
):
111107
permutation_order = transformation["permutation_order"]
112108
boxes, classes = bounding_boxes["boxes"], bounding_boxes["classes"]
113-
boxes_for_mixup = self.backend.numpy.take(boxes, permutation_order)
114-
classes_for_mixup = self.backend.numpy.take(classes, permutation_order)
115-
boxes = self.backend.numpy.concat([boxes, boxes_for_mixup], axis=1)
109+
boxes_for_mix_up = self.backend.numpy.take(boxes, permutation_order)
110+
classes_for_mix_up = self.backend.numpy.take(classes, permutation_order)
111+
boxes = self.backend.numpy.concat([boxes, boxes_for_mix_up], axis=1)
116112
classes = self.backend.numpy.concat(
117-
[classes, classes_for_mixup], axis=1
113+
[classes, classes_for_mix_up], axis=1
118114
)
119115
return {"boxes": boxes, "classes": classes}
120116

@@ -126,13 +122,13 @@ def transform_segmentation_masks(
126122

127123
mix_weight = self.backend.numpy.reshape(mix_weight, [-1, 1, 1, 1])
128124

129-
segmentation_masks_for_mixup = self.backend.numpy.take(
125+
segmentation_masks_for_mix_up = self.backend.numpy.take(
130126
segmentation_masks, permutation_order
131127
)
132128

133129
segmentation_masks = (
134130
mix_weight * segmentation_masks
135-
+ (1.0 - mix_weight) * segmentation_masks_for_mixup
131+
+ (1.0 - mix_weight) * segmentation_masks_for_mix_up
136132
)
137133

138134
return segmentation_masks

0 commit comments

Comments
 (0)