
Commit 345cecc

Add implementations for random_hue (#20620)
* Add implementations for random_hue
* Correct failed test cases
* Correct misspellings
* Update example in the description
* Correct failed test case
1 parent 007f5e7 commit 345cecc

File tree

keras/api/_tf_keras/keras/layers/__init__.py
keras/api/layers/__init__.py
keras/src/layers/__init__.py
keras/src/layers/preprocessing/image_preprocessing/random_hue.py
keras/src/layers/preprocessing/image_preprocessing/random_hue_test.py

5 files changed: +273 -0 lines changed


keras/api/_tf_keras/keras/layers/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -164,6 +164,9 @@
 from keras.src.layers.preprocessing.image_preprocessing.random_flip import (
     RandomFlip,
 )
+from keras.src.layers.preprocessing.image_preprocessing.random_hue import (
+    RandomHue,
+)
 from keras.src.layers.preprocessing.image_preprocessing.random_rotation import (
     RandomRotation,
 )

keras/api/layers/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -164,6 +164,9 @@
 from keras.src.layers.preprocessing.image_preprocessing.random_flip import (
     RandomFlip,
 )
+from keras.src.layers.preprocessing.image_preprocessing.random_hue import (
+    RandomHue,
+)
 from keras.src.layers.preprocessing.image_preprocessing.random_rotation import (
     RandomRotation,
 )

keras/src/layers/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -107,6 +107,9 @@
 from keras.src.layers.preprocessing.image_preprocessing.random_flip import (
     RandomFlip,
 )
+from keras.src.layers.preprocessing.image_preprocessing.random_hue import (
+    RandomHue,
+)
 from keras.src.layers.preprocessing.image_preprocessing.random_rotation import (
     RandomRotation,
 )
keras/src/layers/preprocessing/image_preprocessing/random_hue.py

Lines changed: 198 additions & 0 deletions
@@ -0,0 +1,198 @@
from keras.src.api_export import keras_export
from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import (  # noqa: E501
    BaseImagePreprocessingLayer,
)
from keras.src.random import SeedGenerator


def transform_value_range(images, original_range, target_range):
    if (
        original_range[0] == target_range[0]
        and original_range[1] == target_range[1]
    ):
        return images

    original_min_value, original_max_value = original_range
    target_min_value, target_max_value = target_range

    # images in the [0, 1] scale
    images = (images - original_min_value) / (
        original_max_value - original_min_value
    )

    scale_factor = target_max_value - target_min_value
    return (images * scale_factor) + target_min_value

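# Example: with original_range=(0, 255) and target_range=(0, 1), a pixel value
# of 128 maps to (128 - 0) / (255 - 0) ~= 0.502; `transform_images` below uses
# this helper to move into [0, 1] for the HSV math and back out again.
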
@keras_export("keras.layers.RandomHue")
class RandomHue(BaseImagePreprocessingLayer):
    """Randomly adjusts the hue of the given images.

    This layer will randomly increase or reduce the hue of the input RGB
    images.

    The image hue is adjusted by converting the image(s) to HSV and rotating
    the hue channel (H) by a randomly sampled delta. The image is then
    converted back to RGB.

    Args:
        factor: A single float or a tuple of two floats. `factor` controls
            the extent to which the image hue is impacted. `factor=0.0` makes
            this layer perform a no-op, while a value of `1.0` performs the
            most aggressive hue adjustment available. If a tuple is used, a
            `factor` is sampled between the two values for every image
            augmented. If a single float is used, a value between `0.0` and
            the passed float is sampled. To ensure the value is always the
            same, pass a tuple with two identical floats: `(0.5, 0.5)`.
        value_range: The range of values the incoming images will have.
            Represented as a two-number tuple written `[low, high]`. This is
            typically either `[0, 1]` or `[0, 255]` depending on how your
            preprocessing pipeline is set up.
        seed: Integer. Used to create a random seed.

    Example:

    ```python
    (images, labels), _ = keras.datasets.cifar10.load_data()
    images = images.astype("float32") / 255.0
    random_hue = keras.layers.RandomHue(factor=0.5, value_range=[0, 1])
    augmented_images = random_hue(images)
    ```
    """

    def __init__(
        self, factor, value_range, data_format=None, seed=None, **kwargs
    ):
        super().__init__(data_format=data_format, **kwargs)
        self.factor = factor
        self.value_range = value_range
        self.seed = seed
        self.generator = SeedGenerator(seed)

    def parse_factor(
        self, min_value=0.0, max_value=1.0, param_name="factor", shape=None
    ):
        factors = self.factor
        if isinstance(factors, (int, float)):
            factors = (min_value, factors)

        if factors[0] > factors[1]:
            raise ValueError(
                f"`{param_name}[0] > {param_name}[1]`, "
                f"`{param_name}[0]` must be "
                f"<= `{param_name}[1]`. Got `{param_name}={factors}`"
            )
        if (min_value is not None and factors[0] < min_value) or (
            max_value is not None and factors[1] > max_value
        ):
            raise ValueError(
                f"`{param_name}` should be inside of range "
                f"[{min_value}, {max_value}]. Got {param_name}={factors}"
            )

        if factors[0] == factors[1]:
            return self.backend.numpy.ones(shape=shape) * factors[0]

        return self.backend.random.uniform(
            shape,
            seed=self.generator,
            minval=factors[0],
            maxval=factors[1],
        )

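    # `get_random_transformation` draws one hue factor per image via
    # `parse_factor`: a scalar `factor` samples uniformly from [0, factor],
    # while a (low, high) tuple samples from that interval directly.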
    def get_random_transformation(self, data, training=True, seed=None):
        if isinstance(data, dict):
            images = data["images"]
        else:
            images = data
        images_shape = self.backend.shape(images)
        rank = len(images_shape)
        if rank == 3:
            batch_size = 1
        elif rank == 4:
            batch_size = images_shape[0]
        else:
            raise ValueError(
                "Expected the input image to be rank 3 or 4. Received "
                f"inputs.shape={images_shape}"
            )

        if seed is None:
            seed = self._get_seed_generator(self.backend._backend)
        invert = self.backend.random.uniform((1,), seed=seed)

        invert = self.backend.numpy.where(
            invert > 0.5,
            -self.backend.numpy.ones_like(invert),
            self.backend.numpy.ones_like(invert),
        )

        factor = self.parse_factor(shape=(batch_size,))

        return {"factor": invert * factor * 0.5}

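    # `transform_images` applies `transformation["factor"]` as a signed hue
    # delta: `invert` above picks the direction and the 0.5 scaling caps the
    # shift at half a turn around the hue circle, with wrap-around handled
    # by the `where` clamps below.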
    def transform_images(self, images, transformation=None, training=True):
        images = transform_value_range(images, self.value_range, (0, 1))
        adjust_factors = transformation["factor"]
        adjust_factors = self.backend.cast(adjust_factors, images.dtype)
        adjust_factors = self.backend.numpy.expand_dims(adjust_factors, -1)
        adjust_factors = self.backend.numpy.expand_dims(adjust_factors, -1)

        images = self.backend.image.rgb_to_hsv(
            images, data_format=self.data_format
        )

        if self.data_format == "channels_first":
            h_channel = images[:, 0, :, :] + adjust_factors
            h_channel = self.backend.numpy.where(
                h_channel > 1.0, h_channel - 1.0, h_channel
            )
            h_channel = self.backend.numpy.where(
                h_channel < 0.0, h_channel + 1.0, h_channel
            )
            images = self.backend.numpy.stack(
                [h_channel, images[:, 1, :, :], images[:, 2, :, :]], axis=1
            )
        else:
            h_channel = images[..., 0] + adjust_factors
            h_channel = self.backend.numpy.where(
                h_channel > 1.0, h_channel - 1.0, h_channel
            )
            h_channel = self.backend.numpy.where(
                h_channel < 0.0, h_channel + 1.0, h_channel
            )
            images = self.backend.numpy.stack(
                [h_channel, images[..., 1], images[..., 2]], axis=-1
            )
        images = self.backend.image.hsv_to_rgb(
            images, data_format=self.data_format
        )

        images = self.backend.numpy.clip(images, 0, 1)
        images = transform_value_range(images, (0, 1), self.value_range)

        return images

    def transform_labels(self, labels, transformation, training=True):
        return labels

    def transform_segmentation_masks(
        self, segmentation_masks, transformation, training=True
    ):
        return segmentation_masks

    def transform_bounding_boxes(
        self, bounding_boxes, transformation, training=True
    ):
        return bounding_boxes

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "factor": self.factor,
                "value_range": self.value_range,
                "seed": self.seed,
            }
        )
        return config

    def compute_output_shape(self, input_shape):
        return input_shape
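
For quick experimentation outside the docstring example, here is a minimal standalone sketch of the new layer (illustrative values only, assuming this commit is installed and the default `channels_last` data format):

```python
import numpy as np

import keras

# A toy batch of four 8x8 RGB images with values in [0, 255].
images = np.random.randint(0, 256, size=(4, 8, 8, 3)).astype("float32")

# factor=0.5 means each image receives a hue shift sampled from [-0.25, 0.25]:
# the layer halves the sampled factor and randomizes its sign internally.
layer = keras.layers.RandomHue(factor=0.5, value_range=(0, 255), seed=42)
augmented = layer(images)
print(augmented.shape)  # (4, 8, 8, 3)
```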
keras/src/layers/preprocessing/image_preprocessing/random_hue_test.py

Lines changed: 66 additions & 0 deletions
@@ -0,0 +1,66 @@
import numpy as np
import pytest
from tensorflow import data as tf_data

import keras
from keras.src import backend
from keras.src import layers
from keras.src import testing


class RandomHueTest(testing.TestCase):
    @pytest.mark.requires_trainable_backend
    def test_layer(self):
        self.run_layer_test(
            layers.RandomHue,
            init_kwargs={
                "factor": 0.75,
                "value_range": (20, 200),
                "seed": 1,
            },
            input_shape=(8, 3, 4, 3),
            supports_masking=False,
            expected_output_shape=(8, 3, 4, 3),
        )

    def test_random_hue_value_range(self):
        image = keras.random.uniform(shape=(3, 3, 3), minval=0, maxval=1)

        layer = layers.RandomHue(0.2, (0, 255))
        adjusted_image = layer(image)

        self.assertTrue(keras.ops.numpy.all(adjusted_image >= 0))
        self.assertTrue(keras.ops.numpy.all(adjusted_image <= 1))

    def test_random_hue_no_change_with_zero_factor(self):
        data_format = backend.config.image_data_format()
        if data_format == "channels_last":
            inputs = keras.random.randint((224, 224, 3), 0, 255)
        else:
            inputs = keras.random.randint((3, 224, 224), 0, 255)

        layer = layers.RandomHue(0, (0, 255), data_format=data_format)
        output = layer(inputs, training=False)
        self.assertAllClose(inputs, output, atol=1e-3, rtol=1e-5)

    def test_random_hue_randomness(self):
        image = keras.random.uniform(shape=(3, 3, 3), minval=0, maxval=1)[:5]

        layer = layers.RandomHue(0.2, (0, 255))
        adjusted_images = layer(image)

        self.assertNotAllClose(adjusted_images, image)

    def test_tf_data_compatibility(self):
        data_format = backend.config.image_data_format()
        if data_format == "channels_last":
            input_data = np.random.random((2, 8, 8, 3))
        else:
            input_data = np.random.random((2, 3, 8, 8))
        layer = layers.RandomHue(
            factor=0.5, value_range=[0, 1], data_format=data_format, seed=1337
        )

        ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer)
        for output in ds.take(1):
            output.numpy()
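
To run just these tests locally, one option is invoking pytest on the new module (path inferred from the import `keras.src.layers.preprocessing.image_preprocessing.random_hue`; adjust if it lands elsewhere):

```python
import pytest

# Run only the RandomHue test module, verbosely.
pytest.main(
    [
        "keras/src/layers/preprocessing/image_preprocessing/random_hue_test.py",
        "-v",
    ]
)
```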
