Skip to content

Commit 4c7c4b5

Browse files
authored
Add a condition to verify training status during image processing (#20650)
* Add a condition to verify training status during image processing
* resolve merge conflict
* fix transform_bounding_boxes logic
* add transform_bounding_boxes test
1 parent ed1442e commit 4c7c4b5

File tree

4 files changed

+220
-84
lines changed

4 files changed

+220
-84
lines changed

keras/src/layers/preprocessing/image_preprocessing/mix_up.py

Lines changed: 79 additions & 49 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@
33
BaseImagePreprocessingLayer,
44
)
55
from keras.src.random import SeedGenerator
6+
from keras.src.utils import backend_utils
67

78

89
@keras_export("keras.layers.MixUp")
@@ -66,36 +67,40 @@ def get_random_transformation(self, data, training=True, seed=None):
6667
}
6768

6869
def transform_images(self, images, transformation=None, training=True):
69-
images = self.backend.cast(images, self.compute_dtype)
70-
mix_weight = transformation["mix_weight"]
71-
permutation_order = transformation["permutation_order"]
72-
73-
mix_weight = self.backend.cast(
74-
self.backend.numpy.reshape(mix_weight, [-1, 1, 1, 1]),
75-
dtype=self.compute_dtype,
76-
)
77-
78-
mix_up_images = self.backend.cast(
79-
self.backend.numpy.take(images, permutation_order, axis=0),
80-
dtype=self.compute_dtype,
81-
)
82-
83-
images = mix_weight * images + (1.0 - mix_weight) * mix_up_images
84-
70+
def _mix_up_input(images, transformation):
71+
images = self.backend.cast(images, self.compute_dtype)
72+
mix_weight = transformation["mix_weight"]
73+
permutation_order = transformation["permutation_order"]
74+
mix_weight = self.backend.cast(
75+
self.backend.numpy.reshape(mix_weight, [-1, 1, 1, 1]),
76+
dtype=self.compute_dtype,
77+
)
78+
mix_up_images = self.backend.cast(
79+
self.backend.numpy.take(images, permutation_order, axis=0),
80+
dtype=self.compute_dtype,
81+
)
82+
images = mix_weight * images + (1.0 - mix_weight) * mix_up_images
83+
return images
84+
85+
if training:
86+
images = _mix_up_input(images, transformation)
8587
return images
8688

8789
def transform_labels(self, labels, transformation, training=True):
88-
mix_weight = transformation["mix_weight"]
89-
permutation_order = transformation["permutation_order"]
90-
91-
labels_for_mix_up = self.backend.numpy.take(
92-
labels, permutation_order, axis=0
93-
)
94-
95-
mix_weight = self.backend.numpy.reshape(mix_weight, [-1, 1])
96-
97-
labels = mix_weight * labels + (1.0 - mix_weight) * labels_for_mix_up
98-
90+
def _mix_up_labels(labels, transformation):
91+
mix_weight = transformation["mix_weight"]
92+
permutation_order = transformation["permutation_order"]
93+
labels_for_mix_up = self.backend.numpy.take(
94+
labels, permutation_order, axis=0
95+
)
96+
mix_weight = self.backend.numpy.reshape(mix_weight, [-1, 1])
97+
labels = (
98+
mix_weight * labels + (1.0 - mix_weight) * labels_for_mix_up
99+
)
100+
return labels
101+
102+
if training:
103+
labels = _mix_up_labels(labels, transformation)
99104
return labels
100105

101106
def transform_bounding_boxes(
@@ -104,33 +109,58 @@ def transform_bounding_boxes(
104109
transformation,
105110
training=True,
106111
):
107-
permutation_order = transformation["permutation_order"]
108-
boxes, classes = bounding_boxes["boxes"], bounding_boxes["classes"]
109-
boxes_for_mix_up = self.backend.numpy.take(boxes, permutation_order)
110-
classes_for_mix_up = self.backend.numpy.take(classes, permutation_order)
111-
boxes = self.backend.numpy.concat([boxes, boxes_for_mix_up], axis=1)
112-
classes = self.backend.numpy.concat(
113-
[classes, classes_for_mix_up], axis=1
114-
)
115-
return {"boxes": boxes, "classes": classes}
112+
def _mix_up_bounding_boxes(bounding_boxes, transformation):
113+
if backend_utils.in_tf_graph():
114+
self.backend.set_backend("tensorflow")
116115

117-
def transform_segmentation_masks(
118-
self, segmentation_masks, transformation, training=True
119-
):
120-
mix_weight = transformation["mix_weight"]
121-
permutation_order = transformation["permutation_order"]
116+
permutation_order = transformation["permutation_order"]
122117

123-
mix_weight = self.backend.numpy.reshape(mix_weight, [-1, 1, 1, 1])
118+
boxes, labels = bounding_boxes["boxes"], bounding_boxes["labels"]
119+
boxes_for_mix_up = self.backend.numpy.take(
120+
boxes, permutation_order, axis=0
121+
)
124122

125-
segmentation_masks_for_mix_up = self.backend.numpy.take(
126-
segmentation_masks, permutation_order
127-
)
123+
labels_for_mix_up = self.backend.numpy.take(
124+
labels, permutation_order, axis=0
125+
)
126+
boxes = self.backend.numpy.concatenate(
127+
[boxes, boxes_for_mix_up], axis=1
128+
)
128129

129-
segmentation_masks = (
130-
mix_weight * segmentation_masks
131-
+ (1.0 - mix_weight) * segmentation_masks_for_mix_up
132-
)
130+
labels = self.backend.numpy.concatenate(
131+
[labels, labels_for_mix_up], axis=0
132+
)
133+
134+
self.backend.reset()
133135

136+
return {"boxes": boxes, "labels": labels}
137+
138+
if training:
139+
bounding_boxes = _mix_up_bounding_boxes(
140+
bounding_boxes, transformation
141+
)
142+
return bounding_boxes
143+
144+
def transform_segmentation_masks(
145+
self, segmentation_masks, transformation, training=True
146+
):
147+
def _mix_up_segmentation_masks(segmentation_masks, transformation):
148+
mix_weight = transformation["mix_weight"]
149+
permutation_order = transformation["permutation_order"]
150+
mix_weight = self.backend.numpy.reshape(mix_weight, [-1, 1, 1, 1])
151+
segmentation_masks_for_mix_up = self.backend.numpy.take(
152+
segmentation_masks, permutation_order
153+
)
154+
segmentation_masks = (
155+
mix_weight * segmentation_masks
156+
+ (1.0 - mix_weight) * segmentation_masks_for_mix_up
157+
)
158+
return segmentation_masks
159+
160+
if training:
161+
segmentation_masks = _mix_up_segmentation_masks(
162+
segmentation_masks, transformation
163+
)
134164
return segmentation_masks
135165

136166
def compute_output_shape(self, input_shape):

keras/src/layers/preprocessing/image_preprocessing/mix_up_test.py

Lines changed: 92 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2,8 +2,10 @@
22
import pytest
33
from tensorflow import data as tf_data
44

5+
from keras.src import backend
56
from keras.src import layers
67
from keras.src import testing
8+
from keras.src.backend import convert_to_tensor
79

810

911
class MixUpTest(testing.TestCase):
@@ -21,6 +23,14 @@ def test_layer(self):
2123
run_training_check=not testing.tensorflow_uses_gpu(),
2224
)
2325

26+
def test_mix_up_inference(self):
27+
seed = 3481
28+
layer = layers.MixUp(alpha=0.2)
29+
np.random.seed(seed)
30+
inputs = np.random.randint(0, 255, size=(224, 224, 3))
31+
output = layer(inputs, training=False)
32+
self.assertAllClose(inputs, output)
33+
2434
def test_mix_up_basic_functionality(self):
2535
image = np.random.random((64, 64, 3))
2636
mix_up_layer = layers.MixUp(alpha=1)
@@ -63,3 +73,85 @@ def test_tf_data_compatibility(self):
6373
ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer)
6474
for output in ds.take(1):
6575
output.numpy()
76+
77+
def test_mix_up_bounding_boxes(self):
78+
data_format = backend.config.image_data_format()
79+
if data_format == "channels_last":
80+
image_shape = (10, 8, 3)
81+
else:
82+
image_shape = (3, 10, 8)
83+
input_image = np.random.random(image_shape)
84+
bounding_boxes = {
85+
"boxes": np.array(
86+
[
87+
[2, 1, 4, 3],
88+
[6, 4, 8, 6],
89+
]
90+
),
91+
"labels": np.array([1, 2]),
92+
}
93+
input_data = {"images": input_image, "bounding_boxes": bounding_boxes}
94+
95+
expected_boxes = [[2, 1, 4, 3, 6, 4, 8, 6], [6, 4, 8, 6, 2, 1, 4, 3]]
96+
97+
random_flip_layer = layers.MixUp(
98+
data_format=data_format,
99+
seed=42,
100+
bounding_box_format="xyxy",
101+
)
102+
103+
transformation = {
104+
"mix_weight": convert_to_tensor([0.5, 0.5]),
105+
"permutation_order": convert_to_tensor([1, 0]),
106+
}
107+
output = random_flip_layer.transform_bounding_boxes(
108+
input_data["bounding_boxes"],
109+
transformation=transformation,
110+
training=True,
111+
)
112+
self.assertAllClose(output["boxes"], expected_boxes)
113+
114+
def test_mix_up_tf_data_bounding_boxes(self):
115+
data_format = backend.config.image_data_format()
116+
if data_format == "channels_last":
117+
image_shape = (1, 10, 8, 3)
118+
else:
119+
image_shape = (1, 3, 10, 8)
120+
input_image = np.random.random(image_shape)
121+
bounding_boxes = {
122+
"boxes": np.array(
123+
[
124+
[
125+
[2, 1, 4, 3],
126+
[6, 4, 8, 6],
127+
]
128+
]
129+
),
130+
"labels": np.array([[1, 2]]),
131+
}
132+
133+
input_data = {"images": input_image, "bounding_boxes": bounding_boxes}
134+
expected_boxes = [[2, 1, 4, 3, 6, 4, 8, 6], [6, 4, 8, 6, 2, 1, 4, 3]]
135+
136+
ds = tf_data.Dataset.from_tensor_slices(input_data)
137+
layer = layers.MixUp(
138+
data_format=data_format,
139+
seed=42,
140+
bounding_box_format="xyxy",
141+
)
142+
143+
transformation = {
144+
"mix_weight": convert_to_tensor([0.5, 0.5]),
145+
"permutation_order": convert_to_tensor([1, 0]),
146+
}
147+
ds = ds.map(
148+
lambda x: layer.transform_bounding_boxes(
149+
x["bounding_boxes"],
150+
transformation=transformation,
151+
training=True,
152+
)
153+
)
154+
155+
output = next(iter(ds))
156+
expected_boxes = np.array(expected_boxes)
157+
self.assertAllClose(output["boxes"], expected_boxes)

keras/src/layers/preprocessing/image_preprocessing/random_hue.py

Lines changed: 41 additions & 35 deletions
Original file line number | Diff line number | Diff line change
@@ -87,46 +87,52 @@ def get_random_transformation(self, data, training=True, seed=None):
8787
return {"factor": invert * factor * 0.5}
8888

8989
def transform_images(self, images, transformation=None, training=True):
90-
images = self.backend.cast(images, self.compute_dtype)
91-
images = self._transform_value_range(images, self.value_range, (0, 1))
92-
adjust_factors = transformation["factor"]
93-
adjust_factors = self.backend.cast(adjust_factors, images.dtype)
94-
adjust_factors = self.backend.numpy.expand_dims(adjust_factors, -1)
95-
adjust_factors = self.backend.numpy.expand_dims(adjust_factors, -1)
96-
97-
images = self.backend.image.rgb_to_hsv(
98-
images, data_format=self.data_format
99-
)
100-
101-
if self.data_format == "channels_first":
102-
h_channel = images[:, 0, :, :] + adjust_factors
103-
h_channel = self.backend.numpy.where(
104-
h_channel > 1.0, h_channel - 1.0, h_channel
105-
)
106-
h_channel = self.backend.numpy.where(
107-
h_channel < 0.0, h_channel + 1.0, h_channel
90+
def _apply_random_hue(images, transformation):
91+
images = self.backend.cast(images, self.compute_dtype)
92+
images = self._transform_value_range(
93+
images, self.value_range, (0, 1)
10894
)
109-
images = self.backend.numpy.stack(
110-
[h_channel, images[:, 1, :, :], images[:, 2, :, :]], axis=1
95+
adjust_factors = transformation["factor"]
96+
adjust_factors = self.backend.cast(adjust_factors, images.dtype)
97+
adjust_factors = self.backend.numpy.expand_dims(adjust_factors, -1)
98+
adjust_factors = self.backend.numpy.expand_dims(adjust_factors, -1)
99+
images = self.backend.image.rgb_to_hsv(
100+
images, data_format=self.data_format
111101
)
112-
else:
113-
h_channel = images[..., 0] + adjust_factors
114-
h_channel = self.backend.numpy.where(
115-
h_channel > 1.0, h_channel - 1.0, h_channel
116-
)
117-
h_channel = self.backend.numpy.where(
118-
h_channel < 0.0, h_channel + 1.0, h_channel
102+
if self.data_format == "channels_first":
103+
h_channel = images[:, 0, :, :] + adjust_factors
104+
h_channel = self.backend.numpy.where(
105+
h_channel > 1.0, h_channel - 1.0, h_channel
106+
)
107+
h_channel = self.backend.numpy.where(
108+
h_channel < 0.0, h_channel + 1.0, h_channel
109+
)
110+
images = self.backend.numpy.stack(
111+
[h_channel, images[:, 1, :, :], images[:, 2, :, :]], axis=1
112+
)
113+
else:
114+
h_channel = images[..., 0] + adjust_factors
115+
h_channel = self.backend.numpy.where(
116+
h_channel > 1.0, h_channel - 1.0, h_channel
117+
)
118+
h_channel = self.backend.numpy.where(
119+
h_channel < 0.0, h_channel + 1.0, h_channel
120+
)
121+
images = self.backend.numpy.stack(
122+
[h_channel, images[..., 1], images[..., 2]], axis=-1
123+
)
124+
images = self.backend.image.hsv_to_rgb(
125+
images, data_format=self.data_format
119126
)
120-
images = self.backend.numpy.stack(
121-
[h_channel, images[..., 1], images[..., 2]], axis=-1
127+
images = self.backend.numpy.clip(images, 0, 1)
128+
images = self._transform_value_range(
129+
images, (0, 1), self.value_range
122130
)
123-
images = self.backend.image.hsv_to_rgb(
124-
images, data_format=self.data_format
125-
)
131+
images = self.backend.cast(images, self.compute_dtype)
132+
return images
126133

127-
images = self.backend.numpy.clip(images, 0, 1)
128-
images = self._transform_value_range(images, (0, 1), self.value_range)
129-
images = self.backend.cast(images, self.compute_dtype)
134+
if training:
135+
images = _apply_random_hue(images, transformation)
130136
return images
131137

132138
def transform_labels(self, labels, transformation, training=True):

keras/src/layers/preprocessing/image_preprocessing/random_hue_test.py

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,14 @@ def test_layer(self):
2323
expected_output_shape=(8, 3, 4, 3),
2424
)
2525

26+
def test_random_hue_inference(self):
27+
seed = 3481
28+
layer = layers.RandomHue(0.2, [0, 1.0])
29+
np.random.seed(seed)
30+
inputs = np.random.randint(0, 255, size=(224, 224, 3))
31+
output = layer(inputs, training=False)
32+
self.assertAllClose(inputs, output)
33+
2634
def test_random_hue_value_range(self):
2735
image = keras.random.uniform(shape=(3, 3, 3), minval=0, maxval=1)
2836

0 commit comments

Comments
 (0)