Skip to content

Commit 129e3d7

Browse files
authored
Fix RandAugment behavior in tf graph mode (#21499)
* correct issue for rand_augment * Update the logic to allow replaceable * Update performance
1 parent 55263ca commit 129e3d7

File tree

2 files changed

+59
-14
lines changed

2 files changed

+59
-14
lines changed

keras/src/layers/preprocessing/image_preprocessing/rand_augment.py

Lines changed: 44 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
import random
2-
31
import keras.src.layers as layers
42
from keras.src.api_export import keras_export
53
from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501
@@ -169,9 +167,15 @@ def get_random_transformation(self, data, training=True, seed=None):
169167
augmentation_layer = getattr(self, layer_name)
170168
augmentation_layer.backend.set_backend("tensorflow")
171169

170+
layer_idxes = self.backend.random.randint(
171+
(self.num_ops,),
172+
0,
173+
len(self._AUGMENT_LAYERS),
174+
seed=self._get_seed_generator(self.backend._backend),
175+
)
176+
172177
transformation = {}
173-
random.shuffle(self._AUGMENT_LAYERS)
174-
for layer_name in self._AUGMENT_LAYERS[: self.num_ops]:
178+
for layer_name in self._AUGMENT_LAYERS:
175179
augmentation_layer = getattr(self, layer_name)
176180
transformation[layer_name] = (
177181
augmentation_layer.get_random_transformation(
@@ -181,17 +185,25 @@ def get_random_transformation(self, data, training=True, seed=None):
181185
)
182186
)
183187

184-
return transformation
188+
return {
189+
"transforms": transformation,
190+
"layer_idxes": layer_idxes,
191+
}
185192

186193
def transform_images(self, images, transformation, training=True):
187194
if training:
188195
images = self.backend.cast(images, self.compute_dtype)
189196

190-
for layer_name, transformation_value in transformation.items():
191-
augmentation_layer = getattr(self, layer_name)
192-
images = augmentation_layer.transform_images(
193-
images, transformation_value
194-
)
197+
layer_idxes = transformation["layer_idxes"]
198+
transforms = transformation["transforms"]
199+
for i in range(self.num_ops):
200+
for idx, (key, value) in enumerate(transforms.items()):
201+
augmentation_layer = getattr(self, key)
202+
images = self.backend.numpy.where(
203+
layer_idxes[i] == idx,
204+
augmentation_layer.transform_images(images, value),
205+
images,
206+
)
195207

196208
images = self.backend.cast(images, self.compute_dtype)
197209
return images
@@ -206,11 +218,29 @@ def transform_bounding_boxes(
206218
training=True,
207219
):
208220
if training:
209-
for layer_name, transformation_value in transformation.items():
210-
augmentation_layer = getattr(self, layer_name)
211-
bounding_boxes = augmentation_layer.transform_bounding_boxes(
212-
bounding_boxes, transformation_value, training=training
221+
layer_idxes = transformation["layer_idxes"]
222+
transforms = transformation["transforms"]
223+
for idx, (key, value) in enumerate(transforms.items()):
224+
augmentation_layer = getattr(self, key)
225+
226+
transformed_bounding_box = (
227+
augmentation_layer.transform_bounding_boxes(
228+
bounding_boxes.copy(), value
229+
)
213230
)
231+
for i in range(self.num_ops):
232+
bounding_boxes["boxes"] = self.backend.numpy.where(
233+
layer_idxes[i] == idx,
234+
transformed_bounding_box["boxes"],
235+
bounding_boxes["boxes"],
236+
)
237+
238+
bounding_boxes["labels"] = self.backend.numpy.where(
239+
layer_idxes[i] == idx,
240+
transformed_bounding_box["labels"],
241+
bounding_boxes["labels"],
242+
)
243+
214244
return bounding_boxes
215245

216246
def transform_segmentation_masks(

keras/src/layers/preprocessing/image_preprocessing/rand_augment_test.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,3 +112,18 @@ def test_rand_augment_tf_data_bounding_boxes(self):
112112
bounding_box_format="xyxy",
113113
)
114114
ds.map(layer)
115+
116+
def test_graph_issue(self):
117+
input_data = np.random.random((10, 8, 8, 3))
118+
layer = layers.RandAugment()
119+
ds = (
120+
tf_data.Dataset.from_tensor_slices(input_data)
121+
.batch(2)
122+
.map(lambda x: layer.get_random_transformation(x))
123+
)
124+
125+
key_list = []
126+
for output in ds:
127+
key_list.append(output["layer_idxes"])
128+
129+
self.assertNotEquals(len(np.unique(key_list)), 1)

0 commit comments

Comments
 (0)