
Commit 5fc7b6a

Add RandomGrayscale Layer (#20639)
* Add RandomGrayscale Layer
* Fix torch tests
* format
* fix
* fix
* Fix torch tests
1 parent 4aa6a67 commit 5fc7b6a

File tree

5 files changed: +206 -0 lines changed

keras/api/_tf_keras/keras/layers/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -164,6 +164,9 @@
 from keras.src.layers.preprocessing.image_preprocessing.random_flip import (
     RandomFlip,
 )
+from keras.src.layers.preprocessing.image_preprocessing.random_grayscale import (
+    RandomGrayscale,
+)
 from keras.src.layers.preprocessing.image_preprocessing.random_hue import (
     RandomHue,
 )

keras/api/layers/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -164,6 +164,9 @@
 from keras.src.layers.preprocessing.image_preprocessing.random_flip import (
     RandomFlip,
 )
+from keras.src.layers.preprocessing.image_preprocessing.random_grayscale import (
+    RandomGrayscale,
+)
 from keras.src.layers.preprocessing.image_preprocessing.random_hue import (
     RandomHue,
 )

keras/src/layers/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -108,6 +108,9 @@
 from keras.src.layers.preprocessing.image_preprocessing.random_flip import (
     RandomFlip,
 )
+from keras.src.layers.preprocessing.image_preprocessing.random_grayscale import (
+    RandomGrayscale,
+)
 from keras.src.layers.preprocessing.image_preprocessing.random_hue import (
     RandomHue,
 )
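
The three __init__.py additions above are what expose the new class on the public namespaces. A minimal check of the resulting import surface (assuming a Keras build that includes this commit):

import keras
from keras.src.layers.preprocessing.image_preprocessing.random_grayscale import (
    RandomGrayscale,
)

# Both the public path and the source path resolve to the same class.
assert keras.layers.RandomGrayscale is RandomGrayscale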
keras/src/layers/preprocessing/image_preprocessing/random_grayscale.py

Lines changed: 103 additions & 0 deletions
@@ -0,0 +1,103 @@
from keras.src import backend
from keras.src.api_export import keras_export
from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import (  # noqa: E501
    BaseImagePreprocessingLayer,
)


@keras_export("keras.layers.RandomGrayscale")
class RandomGrayscale(BaseImagePreprocessingLayer):
    """Preprocessing layer for random conversion of RGB images to grayscale.

    This layer randomly converts input images to grayscale with a specified
    factor. When applied, it maintains the original number of channels
    but sets all channels to the same grayscale value. This can be useful
    for data augmentation and training models to be robust to color
    variations.

    The conversion preserves the perceived luminance of the original color
    image using standard RGB to grayscale conversion coefficients. Images
    that are not selected for conversion remain unchanged.

    **Note:** This layer is safe to use inside a `tf.data` pipeline
    (independently of which backend you're using).

    Args:
        factor: Float between 0 and 1, specifying the factor of
            converting each image to grayscale. Defaults to 0.5. A value of
            1.0 means all images will be converted, while 0.0 means no images
            will be converted.
        data_format: String, one of `"channels_last"` (default) or
            `"channels_first"`. The ordering of the dimensions in the inputs.
            `"channels_last"` corresponds to inputs with shape
            `(batch, height, width, channels)` while `"channels_first"`
            corresponds to inputs with shape
            `(batch, channels, height, width)`.

    Input shape:
        3D (unbatched) or 4D (batched) tensor with shape:
        `(..., height, width, channels)`, in `"channels_last"` format,
        or `(..., channels, height, width)`, in `"channels_first"` format.

    Output shape:
        Same as input shape. The output maintains the same number of channels
        as the input, even for grayscale-converted images where all channels
        will have the same value.
    """

    def __init__(self, factor=0.5, data_format=None, **kwargs):
        super().__init__(**kwargs)
        if factor < 0 or factor > 1:
            raise ValueError(
                "`factor` should be between 0 and 1. "
                f"Received: factor={factor}"
            )
        self.factor = factor
        self.data_format = backend.standardize_data_format(data_format)
        self.random_generator = self.backend.random.SeedGenerator()

    def get_random_transformation(self, images, training=True, seed=None):
        random_values = self.backend.random.uniform(
            shape=(self.backend.core.shape(images)[0],),
            minval=0,
            maxval=1,
            seed=self.random_generator,
        )
        should_apply = self.backend.numpy.expand_dims(
            random_values < self.factor, axis=[1, 2, 3]
        )
        return should_apply

    def transform_images(self, images, transformations=None, **kwargs):
        should_apply = (
            transformations
            if transformations is not None
            else self.get_random_transformation(images)
        )

        grayscale_images = self.backend.image.rgb_to_grayscale(
            images, data_format=self.data_format
        )
        return self.backend.numpy.where(should_apply, grayscale_images, images)

    def compute_output_shape(self, input_shape):
        return input_shape

    def compute_output_spec(self, inputs, **kwargs):
        return inputs

    def transform_bounding_boxes(self, bounding_boxes, **kwargs):
        return bounding_boxes

    def transform_labels(self, labels, transformations=None, **kwargs):
        return labels

    def transform_segmentation_masks(
        self, segmentation_masks, transformations=None, **kwargs
    ):
        return segmentation_masks

    def get_config(self):
        config = super().get_config()
        config.update({"factor": self.factor})
        return config
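
The docstring above notes that the layer is safe to use inside a tf.data pipeline. A minimal usage sketch of that pattern (the image size, batch size, and factor here are illustrative, not from the commit, and assume a Keras build that includes this change):

import numpy as np
import keras
from tensorflow import data as tf_data

images = np.random.uniform(0, 255, size=(8, 16, 16, 3)).astype("float32")
augment = keras.layers.RandomGrayscale(factor=0.5)

# With factor=0.5, roughly half of the images in each batch are converted
# to grayscale (all three channels equal); the rest pass through unchanged.
ds = tf_data.Dataset.from_tensor_slices(images).batch(4).map(augment)
for batch in ds.take(1):
    print(batch.shape)  # (4, 16, 16, 3) -- shape and channel count preserved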
keras/src/layers/preprocessing/image_preprocessing/random_grayscale_test.py

Lines changed: 94 additions & 0 deletions
@@ -0,0 +1,94 @@
import numpy as np
import pytest
from absl.testing import parameterized
from tensorflow import data as tf_data

from keras.src import backend
from keras.src import layers
from keras.src import ops
from keras.src import testing


class RandomGrayscaleTest(testing.TestCase):
    @pytest.mark.requires_trainable_backend
    def test_layer(self):
        self.run_layer_test(
            layers.RandomGrayscale,
            init_kwargs={
                "factor": 0.5,
                "data_format": "channels_last",
            },
            input_shape=(1, 2, 2, 3),
            supports_masking=False,
            expected_output_shape=(1, 2, 2, 3),
        )

        self.run_layer_test(
            layers.RandomGrayscale,
            init_kwargs={
                "factor": 0.5,
                "data_format": "channels_first",
            },
            input_shape=(1, 3, 2, 2),
            supports_masking=False,
            expected_output_shape=(1, 3, 2, 2),
        )

    @parameterized.named_parameters(
        ("channels_last", "channels_last"), ("channels_first", "channels_first")
    )
    def test_grayscale_conversion(self, data_format):
        if data_format == "channels_last":
            xs = np.random.uniform(0, 255, size=(2, 4, 4, 3)).astype(np.float32)
            layer = layers.RandomGrayscale(factor=1.0, data_format=data_format)
            transformed = ops.convert_to_numpy(layer(xs))
            self.assertEqual(transformed.shape[-1], 3)
            for img in transformed:
                r, g, b = img[:, :, 0], img[:, :, 1], img[:, :, 2]
                self.assertTrue(np.allclose(r, g) and np.allclose(g, b))
        else:
            xs = np.random.uniform(0, 255, size=(2, 3, 4, 4)).astype(np.float32)
            layer = layers.RandomGrayscale(factor=1.0, data_format=data_format)
            transformed = ops.convert_to_numpy(layer(xs))
            self.assertEqual(transformed.shape[1], 3)
            for img in transformed:
                r, g, b = img[0], img[1], img[2]
                self.assertTrue(np.allclose(r, g) and np.allclose(g, b))

    def test_invalid_factor(self):
        with self.assertRaises(ValueError):
            layers.RandomGrayscale(factor=-0.1)

        with self.assertRaises(ValueError):
            layers.RandomGrayscale(factor=1.1)

    def test_tf_data_compatibility(self):
        data_format = backend.config.image_data_format()
        if data_format == "channels_last":
            input_data = np.random.random((2, 8, 8, 3)) * 255
        else:
            input_data = np.random.random((2, 3, 8, 8)) * 255

        layer = layers.RandomGrayscale(factor=0.5, data_format=data_format)
        ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer)

        for output in ds.take(1):
            output_array = output.numpy()
            self.assertEqual(output_array.shape, input_data.shape)

    def test_grayscale_with_single_color_image(self):
        test_cases = [
            (np.full((1, 4, 4, 3), 128, dtype=np.float32), "channels_last"),
            (np.full((1, 3, 4, 4), 128, dtype=np.float32), "channels_first"),
        ]

        for xs, data_format in test_cases:
            layer = layers.RandomGrayscale(factor=1.0, data_format=data_format)
            transformed = ops.convert_to_numpy(layer(xs))

            if data_format == "channels_last":
                unique_vals = np.unique(transformed[0, :, :, 0])
                self.assertEqual(len(unique_vals), 1)
            else:
                unique_vals = np.unique(transformed[0, 0, :, :])
                self.assertEqual(len(unique_vals), 1)
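
The selection and conversion logic these tests exercise can be mimicked in plain NumPy. A rough reference sketch for channels_last input: the helper name random_grayscale_reference is hypothetical, and the 0.299/0.587/0.114 weights are the standard ITU-R BT.601 luminance coefficients, assumed here as a stand-in for whatever the backend's rgb_to_grayscale uses.

import numpy as np

def random_grayscale_reference(images, factor=0.5, rng=None):
    """NumPy reference for batches of shape (N, H, W, 3), channels_last."""
    rng = rng or np.random.default_rng()
    # Per-image draw, broadcast over (H, W, C), mirroring the layer's
    # expand_dims(random_values < factor, axis=[1, 2, 3]).
    should_apply = rng.uniform(size=(images.shape[0], 1, 1, 1)) < factor
    # Luminance-weighted grayscale, replicated back to three channels.
    weights = np.array([0.299, 0.587, 0.114], dtype=images.dtype)
    gray = np.tensordot(images, weights, axes=([-1], [0]))[..., None]
    gray = np.repeat(gray, 3, axis=-1)
    # Images not selected for conversion pass through unchanged.
    return np.where(should_apply, gray, images)

batch = np.random.uniform(0, 255, size=(4, 8, 8, 3)).astype("float32")
out = random_grayscale_reference(batch, factor=1.0)
assert np.allclose(out[..., 0], out[..., 1]) and np.allclose(out[..., 1], out[..., 2])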
