Skip to content

Commit f54c127

Browse files
authored
Add random_color_degeneration processing layer (#20679)
* Add random_color_degeneration processing layer * Fix mistypo * Correct failed test case
1 parent df002a9 commit f54c127

File tree

5 files changed

+218
-0
lines changed

5 files changed

+218
-0
lines changed

keras/api/_tf_keras/keras/layers/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,9 @@
155155
from keras.src.layers.preprocessing.image_preprocessing.random_brightness import (
156156
RandomBrightness,
157157
)
158+
from keras.src.layers.preprocessing.image_preprocessing.random_color_degeneration import (
159+
RandomColorDegeneration,
160+
)
158161
from keras.src.layers.preprocessing.image_preprocessing.random_color_jitter import (
159162
RandomColorJitter,
160163
)

keras/api/layers/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,9 @@
155155
from keras.src.layers.preprocessing.image_preprocessing.random_brightness import (
156156
RandomBrightness,
157157
)
158+
from keras.src.layers.preprocessing.image_preprocessing.random_color_degeneration import (
159+
RandomColorDegeneration,
160+
)
158161
from keras.src.layers.preprocessing.image_preprocessing.random_color_jitter import (
159162
RandomColorJitter,
160163
)

keras/src/layers/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,9 @@
9999
from keras.src.layers.preprocessing.image_preprocessing.random_brightness import (
100100
RandomBrightness,
101101
)
102+
from keras.src.layers.preprocessing.image_preprocessing.random_color_degeneration import (
103+
RandomColorDegeneration,
104+
)
102105
from keras.src.layers.preprocessing.image_preprocessing.random_color_jitter import (
103106
RandomColorJitter,
104107
)
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
from keras.src.api_export import keras_export
2+
from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501
3+
BaseImagePreprocessingLayer,
4+
)
5+
from keras.src.random import SeedGenerator
6+
7+
8+
@keras_export("keras.layers.RandomColorDegeneration")
9+
class RandomColorDegeneration(BaseImagePreprocessingLayer):
10+
"""Randomly performs the color degeneration operation on given images.
11+
12+
The sharpness operation first converts an image to gray scale, then back to
13+
color. It then takes a weighted average between original image and the
14+
degenerated image. This makes colors appear more dull.
15+
16+
Args:
17+
factor: A tuple of two floats or a single float.
18+
`factor` controls the extent to which the
19+
image sharpness is impacted. `factor=0.0` makes this layer perform a
20+
no-op operation, while a value of 1.0 uses the degenerated result
21+
entirely. Values between 0 and 1 result in linear interpolation
22+
between the original image and the sharpened image.
23+
Values should be between `0.0` and `1.0`. If a tuple is used, a
24+
`factor` is sampled between the two values for every image
25+
augmented. If a single float is used, a value between `0.0` and the
26+
passed float is sampled. In order to ensure the value is always the
27+
same, please pass a tuple with two identical floats: `(0.5, 0.5)`.
28+
seed: Integer. Used to create a random seed.
29+
"""
30+
31+
_VALUE_RANGE_VALIDATION_ERROR = (
32+
"The `value_range` argument should be a list of two numbers. "
33+
)
34+
35+
def __init__(
36+
self,
37+
factor,
38+
value_range=(0, 255),
39+
data_format=None,
40+
seed=None,
41+
**kwargs,
42+
):
43+
super().__init__(data_format=data_format, **kwargs)
44+
self._set_factor(factor)
45+
self._set_value_range(value_range)
46+
self.seed = seed
47+
self.generator = SeedGenerator(seed)
48+
49+
def _set_value_range(self, value_range):
50+
if not isinstance(value_range, (tuple, list)):
51+
raise ValueError(
52+
self._VALUE_RANGE_VALIDATION_ERROR
53+
+ f"Received: value_range={value_range}"
54+
)
55+
if len(value_range) != 2:
56+
raise ValueError(
57+
self._VALUE_RANGE_VALIDATION_ERROR
58+
+ f"Received: value_range={value_range}"
59+
)
60+
self.value_range = sorted(value_range)
61+
62+
def get_random_transformation(self, data, training=True, seed=None):
63+
if isinstance(data, dict):
64+
images = data["images"]
65+
else:
66+
images = data
67+
images_shape = self.backend.shape(images)
68+
rank = len(images_shape)
69+
if rank == 3:
70+
batch_size = 1
71+
elif rank == 4:
72+
batch_size = images_shape[0]
73+
else:
74+
raise ValueError(
75+
"Expected the input image to be rank 3 or 4. Received: "
76+
f"inputs.shape={images_shape}"
77+
)
78+
79+
if seed is None:
80+
seed = self._get_seed_generator(self.backend._backend)
81+
82+
factor = self.backend.random.uniform(
83+
(batch_size, 1, 1, 1),
84+
minval=self.factor[0],
85+
maxval=self.factor[1],
86+
seed=seed,
87+
)
88+
factor = factor
89+
return {"factor": factor}
90+
91+
def transform_images(self, images, transformation=None, training=True):
92+
if training:
93+
images = self.backend.cast(images, self.compute_dtype)
94+
factor = self.backend.cast(
95+
transformation["factor"], self.compute_dtype
96+
)
97+
degenerates = self.backend.image.rgb_to_grayscale(
98+
images, data_format=self.data_format
99+
)
100+
images = images + factor * (degenerates - images)
101+
images = self.backend.numpy.clip(
102+
images, self.value_range[0], self.value_range[1]
103+
)
104+
images = self.backend.cast(images, self.compute_dtype)
105+
return images
106+
107+
def transform_labels(self, labels, transformation, training=True):
108+
return labels
109+
110+
def transform_segmentation_masks(
111+
self, segmentation_masks, transformation, training=True
112+
):
113+
return segmentation_masks
114+
115+
def transform_bounding_boxes(
116+
self, bounding_boxes, transformation, training=True
117+
):
118+
return bounding_boxes
119+
120+
def get_config(self):
121+
config = super().get_config()
122+
config.update(
123+
{
124+
"factor": self.factor,
125+
"value_range": self.value_range,
126+
"seed": self.seed,
127+
}
128+
)
129+
return config
130+
131+
def compute_output_shape(self, input_shape):
132+
return input_shape
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
import numpy as np
2+
import pytest
3+
from tensorflow import data as tf_data
4+
5+
import keras
6+
from keras.src import backend
7+
from keras.src import layers
8+
from keras.src import testing
9+
10+
11+
class RandomColorDegenerationTest(testing.TestCase):
12+
@pytest.mark.requires_trainable_backend
13+
def test_layer(self):
14+
self.run_layer_test(
15+
layers.RandomColorDegeneration,
16+
init_kwargs={
17+
"factor": 0.75,
18+
"value_range": (0, 1),
19+
"seed": 1,
20+
},
21+
input_shape=(8, 3, 4, 3),
22+
supports_masking=False,
23+
expected_output_shape=(8, 3, 4, 3),
24+
)
25+
26+
def test_random_color_degeneration_value_range(self):
27+
image = keras.random.uniform(shape=(3, 3, 3), minval=0, maxval=1)
28+
29+
layer = layers.RandomColorDegeneration(0.2, value_range=(0, 1))
30+
adjusted_image = layer(image)
31+
32+
self.assertTrue(keras.ops.numpy.all(adjusted_image >= 0))
33+
self.assertTrue(keras.ops.numpy.all(adjusted_image <= 1))
34+
35+
def test_random_color_degeneration_no_op(self):
36+
data_format = backend.config.image_data_format()
37+
if data_format == "channels_last":
38+
inputs = np.random.random((2, 8, 8, 3))
39+
else:
40+
inputs = np.random.random((2, 3, 8, 8))
41+
42+
layer = layers.RandomColorDegeneration((0.5, 0.5))
43+
output = layer(inputs, training=False)
44+
self.assertAllClose(inputs, output, atol=1e-3, rtol=1e-5)
45+
46+
def test_random_color_degeneration_factor_zero(self):
47+
data_format = backend.config.image_data_format()
48+
if data_format == "channels_last":
49+
inputs = np.random.random((2, 8, 8, 3))
50+
else:
51+
inputs = np.random.random((2, 3, 8, 8))
52+
layer = layers.RandomColorDegeneration(factor=(0.0, 0.0))
53+
result = layer(inputs)
54+
55+
self.assertAllClose(inputs, result, atol=1e-3, rtol=1e-5)
56+
57+
def test_random_color_degeneration_randomness(self):
58+
image = keras.random.uniform(shape=(3, 3, 3), minval=0, maxval=1)[:5]
59+
60+
layer = layers.RandomColorDegeneration(0.2)
61+
adjusted_images = layer(image)
62+
63+
self.assertNotAllClose(adjusted_images, image)
64+
65+
def test_tf_data_compatibility(self):
66+
data_format = backend.config.image_data_format()
67+
if data_format == "channels_last":
68+
input_data = np.random.random((2, 8, 8, 3))
69+
else:
70+
input_data = np.random.random((2, 3, 8, 8))
71+
layer = layers.RandomColorDegeneration(
72+
factor=0.5, data_format=data_format, seed=1337
73+
)
74+
75+
ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer)
76+
for output in ds.take(1):
77+
output.numpy()

0 commit comments

Comments
 (0)