@@ -23,14 +23,10 @@ class MixUp(BaseImagePreprocessingLayer):
23
23
Example:
24
24
```python
25
25
(images, labels), _ = keras.datasets.cifar10.load_data()
26
- images, labels = images[:10], labels[:10]
27
- # Labels must be floating-point and one-hot encoded
28
- labels = tf.cast(tf.one_hot(labels, 10), tf.float32)
29
- mixup = keras.layers.MixUp(alpha=0.2)
30
- augmented_images, updated_labels = mixup(
31
- {'images': images, 'labels': labels}
32
- )
33
- # output == {'images': updated_images, 'labels': updated_labels}
26
+ images, labels = images[:8], labels[:8]
27
+ labels = keras.ops.cast(keras.ops.one_hot(labels.flatten(), 10), "float32")
28
+ mix_up = keras.layers.MixUp(alpha=0.2)
29
+ output = mix_up({"images": images, "labels": labels})
34
30
```
35
31
"""
36
32
@@ -62,7 +58,7 @@ def get_random_transformation(self, data, training=True, seed=None):
62
58
)
63
59
64
60
mix_weight = self .backend .random .beta (
65
- (1 ,), self .alpha , self .alpha , seed = seed
61
+ (batch_size ,), self .alpha , self .alpha , seed = seed
66
62
)
67
63
return {
68
64
"mix_weight" : mix_weight ,
@@ -79,26 +75,26 @@ def transform_images(self, images, transformation=None, training=True):
79
75
dtype = self .compute_dtype ,
80
76
)
81
77
82
- mixup_images = self .backend .cast (
78
+ mix_up_images = self .backend .cast (
83
79
self .backend .numpy .take (images , permutation_order , axis = 0 ),
84
80
dtype = self .compute_dtype ,
85
81
)
86
82
87
- images = mix_weight * images + (1.0 - mix_weight ) * mixup_images
83
+ images = mix_weight * images + (1.0 - mix_weight ) * mix_up_images
88
84
89
85
return images
90
86
91
87
def transform_labels (self , labels , transformation , training = True ):
92
88
mix_weight = transformation ["mix_weight" ]
93
89
permutation_order = transformation ["permutation_order" ]
94
90
95
- labels_for_mixup = self .backend .numpy .take (
91
+ labels_for_mix_up = self .backend .numpy .take (
96
92
labels , permutation_order , axis = 0
97
93
)
98
94
99
95
mix_weight = self .backend .numpy .reshape (mix_weight , [- 1 , 1 ])
100
96
101
- labels = mix_weight * labels + (1.0 - mix_weight ) * labels_for_mixup
97
+ labels = mix_weight * labels + (1.0 - mix_weight ) * labels_for_mix_up
102
98
103
99
return labels
104
100
@@ -110,11 +106,11 @@ def transform_bounding_boxes(
110
106
):
111
107
permutation_order = transformation ["permutation_order" ]
112
108
boxes , classes = bounding_boxes ["boxes" ], bounding_boxes ["classes" ]
113
- boxes_for_mixup = self .backend .numpy .take (boxes , permutation_order )
114
- classes_for_mixup = self .backend .numpy .take (classes , permutation_order )
115
- boxes = self .backend .numpy .concat ([boxes , boxes_for_mixup ], axis = 1 )
109
+ boxes_for_mix_up = self .backend .numpy .take (boxes , permutation_order )
110
+ classes_for_mix_up = self .backend .numpy .take (classes , permutation_order )
111
+ boxes = self .backend .numpy .concat ([boxes , boxes_for_mix_up ], axis = 1 )
116
112
classes = self .backend .numpy .concat (
117
- [classes , classes_for_mixup ], axis = 1
113
+ [classes , classes_for_mix_up ], axis = 1
118
114
)
119
115
return {"boxes" : boxes , "classes" : classes }
120
116
@@ -126,13 +122,13 @@ def transform_segmentation_masks(
126
122
127
123
mix_weight = self .backend .numpy .reshape (mix_weight , [- 1 , 1 , 1 , 1 ])
128
124
129
- segmentation_masks_for_mixup = self .backend .numpy .take (
125
+ segmentation_masks_for_mix_up = self .backend .numpy .take (
130
126
segmentation_masks , permutation_order
131
127
)
132
128
133
129
segmentation_masks = (
134
130
mix_weight * segmentation_masks
135
- + (1.0 - mix_weight ) * segmentation_masks_for_mixup
131
+ + (1.0 - mix_weight ) * segmentation_masks_for_mix_up
136
132
)
137
133
138
134
return segmentation_masks
0 commit comments