3
3
BaseImagePreprocessingLayer ,
4
4
)
5
5
from keras .src .random import SeedGenerator
6
+ from keras .src .utils import backend_utils
6
7
7
8
8
9
@keras_export ("keras.layers.MixUp" )
@@ -66,36 +67,40 @@ def get_random_transformation(self, data, training=True, seed=None):
66
67
}
67
68
68
69
def transform_images (self , images , transformation = None , training = True ):
69
- images = self .backend .cast (images , self .compute_dtype )
70
- mix_weight = transformation ["mix_weight" ]
71
- permutation_order = transformation ["permutation_order" ]
72
-
73
- mix_weight = self .backend .cast (
74
- self .backend .numpy .reshape (mix_weight , [- 1 , 1 , 1 , 1 ]),
75
- dtype = self .compute_dtype ,
76
- )
77
-
78
- mix_up_images = self .backend .cast (
79
- self .backend .numpy .take (images , permutation_order , axis = 0 ),
80
- dtype = self .compute_dtype ,
81
- )
82
-
83
- images = mix_weight * images + (1.0 - mix_weight ) * mix_up_images
84
-
70
+ def _mix_up_input (images , transformation ):
71
+ images = self .backend .cast (images , self .compute_dtype )
72
+ mix_weight = transformation ["mix_weight" ]
73
+ permutation_order = transformation ["permutation_order" ]
74
+ mix_weight = self .backend .cast (
75
+ self .backend .numpy .reshape (mix_weight , [- 1 , 1 , 1 , 1 ]),
76
+ dtype = self .compute_dtype ,
77
+ )
78
+ mix_up_images = self .backend .cast (
79
+ self .backend .numpy .take (images , permutation_order , axis = 0 ),
80
+ dtype = self .compute_dtype ,
81
+ )
82
+ images = mix_weight * images + (1.0 - mix_weight ) * mix_up_images
83
+ return images
84
+
85
+ if training :
86
+ images = _mix_up_input (images , transformation )
85
87
return images
86
88
87
89
def transform_labels (self , labels , transformation , training = True ):
88
- mix_weight = transformation ["mix_weight" ]
89
- permutation_order = transformation ["permutation_order" ]
90
-
91
- labels_for_mix_up = self .backend .numpy .take (
92
- labels , permutation_order , axis = 0
93
- )
94
-
95
- mix_weight = self .backend .numpy .reshape (mix_weight , [- 1 , 1 ])
96
-
97
- labels = mix_weight * labels + (1.0 - mix_weight ) * labels_for_mix_up
98
-
90
+ def _mix_up_labels (labels , transformation ):
91
+ mix_weight = transformation ["mix_weight" ]
92
+ permutation_order = transformation ["permutation_order" ]
93
+ labels_for_mix_up = self .backend .numpy .take (
94
+ labels , permutation_order , axis = 0
95
+ )
96
+ mix_weight = self .backend .numpy .reshape (mix_weight , [- 1 , 1 ])
97
+ labels = (
98
+ mix_weight * labels + (1.0 - mix_weight ) * labels_for_mix_up
99
+ )
100
+ return labels
101
+
102
+ if training :
103
+ labels = _mix_up_labels (labels , transformation )
99
104
return labels
100
105
101
106
def transform_bounding_boxes (
@@ -104,33 +109,58 @@ def transform_bounding_boxes(
104
109
transformation ,
105
110
training = True ,
106
111
):
107
- permutation_order = transformation ["permutation_order" ]
108
- boxes , classes = bounding_boxes ["boxes" ], bounding_boxes ["classes" ]
109
- boxes_for_mix_up = self .backend .numpy .take (boxes , permutation_order )
110
- classes_for_mix_up = self .backend .numpy .take (classes , permutation_order )
111
- boxes = self .backend .numpy .concat ([boxes , boxes_for_mix_up ], axis = 1 )
112
- classes = self .backend .numpy .concat (
113
- [classes , classes_for_mix_up ], axis = 1
114
- )
115
- return {"boxes" : boxes , "classes" : classes }
112
+ def _mix_up_bounding_boxes (bounding_boxes , transformation ):
113
+ if backend_utils .in_tf_graph ():
114
+ self .backend .set_backend ("tensorflow" )
116
115
117
- def transform_segmentation_masks (
118
- self , segmentation_masks , transformation , training = True
119
- ):
120
- mix_weight = transformation ["mix_weight" ]
121
- permutation_order = transformation ["permutation_order" ]
116
+ permutation_order = transformation ["permutation_order" ]
122
117
123
- mix_weight = self .backend .numpy .reshape (mix_weight , [- 1 , 1 , 1 , 1 ])
118
+ boxes , labels = bounding_boxes ["boxes" ], bounding_boxes ["labels" ]
119
+ boxes_for_mix_up = self .backend .numpy .take (
120
+ boxes , permutation_order , axis = 0
121
+ )
124
122
125
- segmentation_masks_for_mix_up = self .backend .numpy .take (
126
- segmentation_masks , permutation_order
127
- )
123
+ labels_for_mix_up = self .backend .numpy .take (
124
+ labels , permutation_order , axis = 0
125
+ )
126
+ boxes = self .backend .numpy .concatenate (
127
+ [boxes , boxes_for_mix_up ], axis = 1
128
+ )
128
129
129
- segmentation_masks = (
130
- mix_weight * segmentation_masks
131
- + (1.0 - mix_weight ) * segmentation_masks_for_mix_up
132
- )
130
+ labels = self .backend .numpy .concatenate (
131
+ [labels , labels_for_mix_up ], axis = 0
132
+ )
133
+
134
+ self .backend .reset ()
133
135
136
+ return {"boxes" : boxes , "labels" : labels }
137
+
138
+ if training :
139
+ bounding_boxes = _mix_up_bounding_boxes (
140
+ bounding_boxes , transformation
141
+ )
142
+ return bounding_boxes
143
+
144
+ def transform_segmentation_masks (
145
+ self , segmentation_masks , transformation , training = True
146
+ ):
147
+ def _mix_up_segmentation_masks (segmentation_masks , transformation ):
148
+ mix_weight = transformation ["mix_weight" ]
149
+ permutation_order = transformation ["permutation_order" ]
150
+ mix_weight = self .backend .numpy .reshape (mix_weight , [- 1 , 1 , 1 , 1 ])
151
+ segmentation_masks_for_mix_up = self .backend .numpy .take (
152
+ segmentation_masks , permutation_order
153
+ )
154
+ segmentation_masks = (
155
+ mix_weight * segmentation_masks
156
+ + (1.0 - mix_weight ) * segmentation_masks_for_mix_up
157
+ )
158
+ return segmentation_masks
159
+
160
+ if training :
161
+ segmentation_masks = _mix_up_segmentation_masks (
162
+ segmentation_masks , transformation
163
+ )
134
164
return segmentation_masks
135
165
136
166
def compute_output_shape (self , input_shape ):
0 commit comments