1
- import random
2
-
3
1
import keras .src .layers as layers
4
2
from keras .src .api_export import keras_export
5
3
from keras .src .layers .preprocessing .image_preprocessing .base_image_preprocessing_layer import ( # noqa: E501
@@ -169,9 +167,15 @@ def get_random_transformation(self, data, training=True, seed=None):
169
167
augmentation_layer = getattr (self , layer_name )
170
168
augmentation_layer .backend .set_backend ("tensorflow" )
171
169
170
+ layer_idxes = self .backend .random .randint (
171
+ (self .num_ops ,),
172
+ 0 ,
173
+ len (self ._AUGMENT_LAYERS ),
174
+ seed = self ._get_seed_generator (self .backend ._backend ),
175
+ )
176
+
172
177
transformation = {}
173
- random .shuffle (self ._AUGMENT_LAYERS )
174
- for layer_name in self ._AUGMENT_LAYERS [: self .num_ops ]:
178
+ for layer_name in self ._AUGMENT_LAYERS :
175
179
augmentation_layer = getattr (self , layer_name )
176
180
transformation [layer_name ] = (
177
181
augmentation_layer .get_random_transformation (
@@ -181,17 +185,25 @@ def get_random_transformation(self, data, training=True, seed=None):
181
185
)
182
186
)
183
187
184
- return transformation
188
+ return {
189
+ "transforms" : transformation ,
190
+ "layer_idxes" : layer_idxes ,
191
+ }
185
192
186
193
def transform_images (self , images , transformation , training = True ):
187
194
if training :
188
195
images = self .backend .cast (images , self .compute_dtype )
189
196
190
- for layer_name , transformation_value in transformation .items ():
191
- augmentation_layer = getattr (self , layer_name )
192
- images = augmentation_layer .transform_images (
193
- images , transformation_value
194
- )
197
+ layer_idxes = transformation ["layer_idxes" ]
198
+ transforms = transformation ["transforms" ]
199
+ for i in range (self .num_ops ):
200
+ for idx , (key , value ) in enumerate (transforms .items ()):
201
+ augmentation_layer = getattr (self , key )
202
+ images = self .backend .numpy .where (
203
+ layer_idxes [i ] == idx ,
204
+ augmentation_layer .transform_images (images , value ),
205
+ images ,
206
+ )
195
207
196
208
images = self .backend .cast (images , self .compute_dtype )
197
209
return images
@@ -206,11 +218,29 @@ def transform_bounding_boxes(
206
218
training = True ,
207
219
):
208
220
if training :
209
- for layer_name , transformation_value in transformation .items ():
210
- augmentation_layer = getattr (self , layer_name )
211
- bounding_boxes = augmentation_layer .transform_bounding_boxes (
212
- bounding_boxes , transformation_value , training = training
221
+ layer_idxes = transformation ["layer_idxes" ]
222
+ transforms = transformation ["transforms" ]
223
+ for idx , (key , value ) in enumerate (transforms .items ()):
224
+ augmentation_layer = getattr (self , key )
225
+
226
+ transformed_bounding_box = (
227
+ augmentation_layer .transform_bounding_boxes (
228
+ bounding_boxes .copy (), value
229
+ )
213
230
)
231
+ for i in range (self .num_ops ):
232
+ bounding_boxes ["boxes" ] = self .backend .numpy .where (
233
+ layer_idxes [i ] == idx ,
234
+ transformed_bounding_box ["boxes" ],
235
+ bounding_boxes ["boxes" ],
236
+ )
237
+
238
+ bounding_boxes ["labels" ] = self .backend .numpy .where (
239
+ layer_idxes [i ] == idx ,
240
+ transformed_bounding_box ["labels" ],
241
+ bounding_boxes ["labels" ],
242
+ )
243
+
214
244
return bounding_boxes
215
245
216
246
def transform_segmentation_masks (
0 commit comments