Skip to content

Commit 9bcf324

Browse files
authored
Add implementations for random_saturation (#20646)
* Correct bug for MixUp initialization. * Update format indent * Add implementations for random_saturation * change parse_factor method to inner method. * correct test cases failed. * correct failed test cases * Add training argument check condition * correct source code * add value_range args description * update description example * change _apply_random_saturation method to inline
1 parent 2b6c800 commit 9bcf324

File tree

5 files changed

+273
-0
lines changed

5 files changed

+273
-0
lines changed

keras/api/_tf_keras/keras/layers/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,9 @@
173173
from keras.src.layers.preprocessing.image_preprocessing.random_rotation import (
174174
RandomRotation,
175175
)
176+
from keras.src.layers.preprocessing.image_preprocessing.random_saturation import (
177+
RandomSaturation,
178+
)
176179
from keras.src.layers.preprocessing.image_preprocessing.random_translation import (
177180
RandomTranslation,
178181
)

keras/api/layers/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,9 @@
173173
from keras.src.layers.preprocessing.image_preprocessing.random_rotation import (
174174
RandomRotation,
175175
)
176+
from keras.src.layers.preprocessing.image_preprocessing.random_saturation import (
177+
RandomSaturation,
178+
)
176179
from keras.src.layers.preprocessing.image_preprocessing.random_translation import (
177180
RandomTranslation,
178181
)

keras/src/layers/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,9 @@
117117
from keras.src.layers.preprocessing.image_preprocessing.random_rotation import (
118118
RandomRotation,
119119
)
120+
from keras.src.layers.preprocessing.image_preprocessing.random_saturation import (
121+
RandomSaturation,
122+
)
120123
from keras.src.layers.preprocessing.image_preprocessing.random_translation import (
121124
RandomTranslation,
122125
)
Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
from keras.src.api_export import keras_export
2+
from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501
3+
BaseImagePreprocessingLayer,
4+
)
5+
from keras.src.random import SeedGenerator
6+
7+
8+
@keras_export("keras.layers.RandomSaturation")
9+
class RandomSaturation(BaseImagePreprocessingLayer):
10+
"""Randomly adjusts the saturation on given images.
11+
12+
This layer will randomly increase/reduce the saturation for the input RGB
13+
images.
14+
15+
Args:
16+
factor: A tuple of two floats or a single float.
17+
`factor` controls the extent to which the image saturation
18+
is impacted. `factor=0.5` makes this layer perform a no-op
19+
operation. `factor=0.0` makes the image fully grayscale.
20+
`factor=1.0` makes the image fully saturated. Values should
21+
be between `0.0` and `1.0`. If a tuple is used, a `factor`
22+
is sampled between the two values for every image augmented.
23+
If a single float is used, a value between `0.0` and the passed
24+
float is sampled. To ensure the value is always the same,
25+
pass a tuple with two identical floats: `(0.5, 0.5)`.
26+
value_range: the range of values the incoming images will have.
27+
Represented as a two-number tuple written `[low, high]`. This is
28+
typically either `[0, 1]` or `[0, 255]` depending on how your
29+
preprocessing pipeline is set up.
30+
seed: Integer. Used to create a random seed.
31+
32+
Example:
33+
```python
34+
(images, labels), _ = keras.datasets.cifar10.load_data()
35+
images = images.astype("float32")
36+
random_saturation = keras.layers.RandomSaturation(factor=0.2)
37+
augmented_images = random_saturation(images)
38+
```
39+
"""
40+
41+
_VALUE_RANGE_VALIDATION_ERROR = (
42+
"The `value_range` argument should be a list of two numbers. "
43+
)
44+
45+
def __init__(
46+
self,
47+
factor,
48+
value_range=(0, 255),
49+
data_format=None,
50+
seed=None,
51+
**kwargs,
52+
):
53+
super().__init__(data_format=data_format, **kwargs)
54+
self._set_factor(factor)
55+
self._set_value_range(value_range)
56+
self.seed = seed
57+
self.generator = SeedGenerator(seed)
58+
59+
def _set_value_range(self, value_range):
60+
if not isinstance(value_range, (tuple, list)):
61+
raise ValueError(
62+
self._VALUE_RANGE_VALIDATION_ERROR
63+
+ f"Received: value_range={value_range}"
64+
)
65+
if len(value_range) != 2:
66+
raise ValueError(
67+
self._VALUE_RANGE_VALIDATION_ERROR
68+
+ f"Received: value_range={value_range}"
69+
)
70+
self.value_range = sorted(value_range)
71+
72+
def get_random_transformation(self, data, training=True, seed=None):
73+
if isinstance(data, dict):
74+
images = data["images"]
75+
else:
76+
images = data
77+
images_shape = self.backend.shape(images)
78+
rank = len(images_shape)
79+
if rank == 3:
80+
batch_size = 1
81+
elif rank == 4:
82+
batch_size = images_shape[0]
83+
else:
84+
raise ValueError(
85+
"Expected the input image to be rank 3 or 4. Received: "
86+
f"inputs.shape={images_shape}"
87+
)
88+
89+
if seed is None:
90+
seed = self._get_seed_generator(self.backend._backend)
91+
92+
factor = self.backend.random.uniform(
93+
(batch_size,),
94+
minval=self.factor[0],
95+
maxval=self.factor[1],
96+
seed=seed,
97+
)
98+
factor = factor / (1 - factor)
99+
return {"factor": factor}
100+
101+
def transform_images(self, images, transformation=None, training=True):
102+
def _apply_random_saturation(images, transformation):
103+
adjust_factors = transformation["factor"]
104+
adjust_factors = self.backend.cast(
105+
adjust_factors, self.compute_dtype
106+
)
107+
adjust_factors = self.backend.numpy.reshape(
108+
adjust_factors, self.backend.shape(adjust_factors) + (1, 1)
109+
)
110+
images = self.backend.image.rgb_to_hsv(
111+
images, data_format=self.data_format
112+
)
113+
if self.data_format == "channels_first":
114+
s_channel = self.backend.numpy.multiply(
115+
images[:, 1, :, :], adjust_factors
116+
)
117+
s_channel = self.backend.numpy.clip(
118+
s_channel, self.value_range[0], self.value_range[1]
119+
)
120+
images = self.backend.numpy.stack(
121+
[images[:, 0, :, :], s_channel, images[:, 2, :, :]], axis=1
122+
)
123+
else:
124+
s_channel = self.backend.numpy.multiply(
125+
images[..., 1], adjust_factors
126+
)
127+
s_channel = self.backend.numpy.clip(
128+
s_channel, self.value_range[0], self.value_range[1]
129+
)
130+
images = self.backend.numpy.stack(
131+
[images[..., 0], s_channel, images[..., 2]], axis=-1
132+
)
133+
images = self.backend.image.hsv_to_rgb(
134+
images, data_format=self.data_format
135+
)
136+
return images
137+
138+
if training:
139+
images = _apply_random_saturation(images, transformation)
140+
return images
141+
142+
def transform_labels(self, labels, transformation, training=True):
143+
return labels
144+
145+
def transform_segmentation_masks(
146+
self, segmentation_masks, transformation, training=True
147+
):
148+
return segmentation_masks
149+
150+
def transform_bounding_boxes(
151+
self, bounding_boxes, transformation, training=True
152+
):
153+
return bounding_boxes
154+
155+
def get_config(self):
156+
config = super().get_config()
157+
config.update(
158+
{
159+
"factor": self.factor,
160+
"value_range": self.value_range,
161+
"seed": self.seed,
162+
}
163+
)
164+
return config
165+
166+
def compute_output_shape(self, input_shape):
167+
return input_shape
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
import numpy as np
2+
import pytest
3+
from tensorflow import data as tf_data
4+
5+
import keras
6+
from keras.src import backend
7+
from keras.src import layers
8+
from keras.src import testing
9+
10+
11+
class RandomSaturationTest(testing.TestCase):
12+
@pytest.mark.requires_trainable_backend
13+
def test_layer(self):
14+
self.run_layer_test(
15+
layers.RandomSaturation,
16+
init_kwargs={
17+
"factor": 0.75,
18+
"seed": 1,
19+
},
20+
input_shape=(8, 3, 4, 3),
21+
supports_masking=False,
22+
expected_output_shape=(8, 3, 4, 3),
23+
)
24+
25+
def test_random_saturation_value_range(self):
26+
image = keras.random.uniform(shape=(3, 3, 3), minval=0, maxval=1)
27+
28+
layer = layers.RandomSaturation(0.2)
29+
adjusted_image = layer(image)
30+
31+
self.assertTrue(keras.ops.numpy.all(adjusted_image >= 0))
32+
self.assertTrue(keras.ops.numpy.all(adjusted_image <= 1))
33+
34+
def test_random_saturation_no_op(self):
35+
data_format = backend.config.image_data_format()
36+
if data_format == "channels_last":
37+
inputs = np.random.random((2, 8, 8, 3))
38+
else:
39+
inputs = np.random.random((2, 3, 8, 8))
40+
41+
layer = layers.RandomSaturation((0.5, 0.5))
42+
output = layer(inputs, training=False)
43+
self.assertAllClose(inputs, output, atol=1e-3, rtol=1e-5)
44+
45+
def test_random_saturation_full_grayscale(self):
46+
data_format = backend.config.image_data_format()
47+
if data_format == "channels_last":
48+
inputs = np.random.random((2, 8, 8, 3))
49+
else:
50+
inputs = np.random.random((2, 3, 8, 8))
51+
layer = layers.RandomSaturation(factor=(0.0, 0.0))
52+
result = layer(inputs)
53+
54+
if data_format == "channels_last":
55+
self.assertAllClose(result[..., 0], result[..., 1])
56+
self.assertAllClose(result[..., 1], result[..., 2])
57+
else:
58+
self.assertAllClose(result[:, 0, :, :], result[:, 1, :, :])
59+
self.assertAllClose(result[:, 1, :, :], result[:, 2, :, :])
60+
61+
def test_random_saturation_full_saturation(self):
62+
data_format = backend.config.image_data_format()
63+
if data_format == "channels_last":
64+
inputs = np.random.random((2, 8, 8, 3))
65+
else:
66+
inputs = np.random.random((2, 3, 8, 8))
67+
layer = layers.RandomSaturation(factor=(1.0, 1.0))
68+
result = layer(inputs)
69+
70+
hsv = backend.image.rgb_to_hsv(result)
71+
s_channel = hsv[..., 1]
72+
73+
self.assertAllClose(
74+
keras.ops.numpy.max(s_channel), layer.value_range[1]
75+
)
76+
77+
def test_random_saturation_randomness(self):
78+
image = keras.random.uniform(shape=(3, 3, 3), minval=0, maxval=1)[:5]
79+
80+
layer = layers.RandomSaturation(0.2)
81+
adjusted_images = layer(image)
82+
83+
self.assertNotAllClose(adjusted_images, image)
84+
85+
def test_tf_data_compatibility(self):
86+
data_format = backend.config.image_data_format()
87+
if data_format == "channels_last":
88+
input_data = np.random.random((2, 8, 8, 3))
89+
else:
90+
input_data = np.random.random((2, 3, 8, 8))
91+
layer = layers.RandomSaturation(
92+
factor=0.5, data_format=data_format, seed=1337
93+
)
94+
95+
ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer)
96+
for output in ds.take(1):
97+
output.numpy()

0 commit comments

Comments
 (0)