Skip to content

Commit be1191f

Browse files
authored
Add random_posterization processing layer (#20688)
* Add random_posterization processing layer * Add test cases * correct failed case
1 parent 67d1ddf commit be1191f

File tree

5 files changed

+245
-0
lines changed

5 files changed

+245
-0
lines changed

keras/api/_tf_keras/keras/layers/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,9 @@
176176
from keras.src.layers.preprocessing.image_preprocessing.random_hue import (
177177
RandomHue,
178178
)
179+
from keras.src.layers.preprocessing.image_preprocessing.random_posterization import (
180+
RandomPosterization,
181+
)
179182
from keras.src.layers.preprocessing.image_preprocessing.random_rotation import (
180183
RandomRotation,
181184
)

keras/api/layers/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,9 @@
176176
from keras.src.layers.preprocessing.image_preprocessing.random_hue import (
177177
RandomHue,
178178
)
179+
from keras.src.layers.preprocessing.image_preprocessing.random_posterization import (
180+
RandomPosterization,
181+
)
179182
from keras.src.layers.preprocessing.image_preprocessing.random_rotation import (
180183
RandomRotation,
181184
)

keras/src/layers/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,9 @@
120120
from keras.src.layers.preprocessing.image_preprocessing.random_hue import (
121121
RandomHue,
122122
)
123+
from keras.src.layers.preprocessing.image_preprocessing.random_posterization import (
124+
RandomPosterization,
125+
)
123126
from keras.src.layers.preprocessing.image_preprocessing.random_rotation import (
124127
RandomRotation,
125128
)
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
from keras.src.api_export import keras_export
2+
from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501
3+
BaseImagePreprocessingLayer,
4+
)
5+
6+
7+
@keras_export("keras.layers.RandomPosterization")
8+
class RandomPosterization(BaseImagePreprocessingLayer):
9+
"""Reduces the number of bits for each color channel.
10+
11+
References:
12+
- [AutoAugment: Learning Augmentation Policies from Data](https://arxiv.org/abs/1805.09501)
13+
- [RandAugment: Practical automated data augmentation with a reduced search space](https://arxiv.org/abs/1909.13719)
14+
15+
Args:
16+
value_range: a tuple or a list of two elements. The first value
17+
represents the lower bound for values in passed images, the second
18+
represents the upper bound. Images passed to the layer should have
19+
values within `value_range`. Defaults to `(0, 255)`.
20+
factor: integer, the number of bits to keep for each channel. Must be a
21+
value between 1-8.
22+
"""
23+
24+
_USE_BASE_FACTOR = False
25+
_FACTOR_BOUNDS = (1, 8)
26+
_MAX_FACTOR = 8
27+
_VALUE_RANGE_VALIDATION_ERROR = (
28+
"The `value_range` argument should be a list of two numbers. "
29+
)
30+
31+
def __init__(
32+
self,
33+
factor,
34+
value_range=(0, 255),
35+
data_format=None,
36+
seed=None,
37+
**kwargs,
38+
):
39+
super().__init__(data_format=data_format, **kwargs)
40+
self._set_factor(factor)
41+
self._set_value_range(value_range)
42+
self.seed = seed
43+
self.generator = self.backend.random.SeedGenerator(seed)
44+
45+
def _set_value_range(self, value_range):
46+
if not isinstance(value_range, (tuple, list)):
47+
raise ValueError(
48+
self._VALUE_RANGE_VALIDATION_ERROR
49+
+ f"Received: value_range={value_range}"
50+
)
51+
if len(value_range) != 2:
52+
raise ValueError(
53+
self._VALUE_RANGE_VALIDATION_ERROR
54+
+ f"Received: value_range={value_range}"
55+
)
56+
self.value_range = sorted(value_range)
57+
58+
def get_random_transformation(self, data, training=True, seed=None):
59+
if isinstance(data, dict):
60+
images = data["images"]
61+
else:
62+
images = data
63+
images_shape = self.backend.shape(images)
64+
rank = len(images_shape)
65+
if rank == 3:
66+
batch_size = 1
67+
elif rank == 4:
68+
batch_size = images_shape[0]
69+
else:
70+
raise ValueError(
71+
"Expected the input image to be rank 3 or 4. Received: "
72+
f"inputs.shape={images_shape}"
73+
)
74+
75+
if seed is None:
76+
seed = self._get_seed_generator(self.backend._backend)
77+
78+
if self.factor[0] != self.factor[1]:
79+
factor = self.backend.random.randint(
80+
(batch_size,),
81+
minval=self.factor[0],
82+
maxval=self.factor[1],
83+
seed=seed,
84+
dtype="uint8",
85+
)
86+
else:
87+
factor = (
88+
self.backend.numpy.ones((batch_size,), dtype="uint8")
89+
* self.factor[0]
90+
)
91+
92+
shift_factor = self._MAX_FACTOR - factor
93+
return {"shift_factor": shift_factor}
94+
95+
def transform_images(self, images, transformation=None, training=True):
96+
if training:
97+
shift_factor = transformation["shift_factor"]
98+
99+
shift_factor = self.backend.numpy.reshape(
100+
shift_factor, self.backend.shape(shift_factor) + (1, 1, 1)
101+
)
102+
103+
images = self._transform_value_range(
104+
images,
105+
original_range=self.value_range,
106+
target_range=(0, 255),
107+
dtype=self.compute_dtype,
108+
)
109+
110+
images = self.backend.cast(images, "uint8")
111+
images = self.backend.numpy.bitwise_left_shift(
112+
self.backend.numpy.bitwise_right_shift(images, shift_factor),
113+
shift_factor,
114+
)
115+
images = self.backend.cast(images, self.compute_dtype)
116+
117+
images = self._transform_value_range(
118+
images,
119+
original_range=(0, 255),
120+
target_range=self.value_range,
121+
dtype=self.compute_dtype,
122+
)
123+
124+
return images
125+
126+
def transform_labels(self, labels, transformation, training=True):
127+
return labels
128+
129+
def transform_segmentation_masks(
130+
self, segmentation_masks, transformation, training=True
131+
):
132+
return segmentation_masks
133+
134+
def transform_bounding_boxes(
135+
self, bounding_boxes, transformation, training=True
136+
):
137+
return bounding_boxes
138+
139+
def get_config(self):
140+
config = super().get_config()
141+
config.update(
142+
{
143+
"factor": self.factor,
144+
"value_range": self.value_range,
145+
"seed": self.seed,
146+
}
147+
)
148+
return config
149+
150+
def compute_output_shape(self, input_shape):
151+
return input_shape
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
import numpy as np
2+
import pytest
3+
from tensorflow import data as tf_data
4+
5+
import keras
6+
from keras.src import backend
7+
from keras.src import layers
8+
from keras.src import testing
9+
10+
11+
class RandomPosterizationTest(testing.TestCase):
12+
@pytest.mark.requires_trainable_backend
13+
def test_layer(self):
14+
self.run_layer_test(
15+
layers.RandomPosterization,
16+
init_kwargs={
17+
"factor": 1,
18+
"value_range": (20, 200),
19+
"seed": 1,
20+
},
21+
input_shape=(8, 3, 4, 3),
22+
supports_masking=False,
23+
expected_output_shape=(8, 3, 4, 3),
24+
)
25+
26+
def test_random_posterization_inference(self):
27+
seed = 3481
28+
layer = layers.RandomPosterization(1, [0, 255])
29+
np.random.seed(seed)
30+
inputs = np.random.randint(0, 255, size=(224, 224, 3))
31+
output = layer(inputs, training=False)
32+
self.assertAllClose(inputs, output)
33+
34+
def test_random_posterization_basic(self):
35+
seed = 3481
36+
layer = layers.RandomPosterization(
37+
1, [0, 255], data_format="channels_last", seed=seed
38+
)
39+
np.random.seed(seed)
40+
inputs = np.asarray(
41+
[[[128.0, 235.0, 87.0], [12.0, 1.0, 23.0], [24.0, 18.0, 121.0]]]
42+
)
43+
output = layer(inputs)
44+
expected_output = np.asarray(
45+
[[[128.0, 128.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]
46+
)
47+
self.assertAllClose(expected_output, output)
48+
49+
def test_random_posterization_value_range_0_to_1(self):
50+
image = keras.random.uniform(shape=(3, 3, 3), minval=0, maxval=1)
51+
52+
layer = layers.RandomPosterization(1, [0, 1.0])
53+
adjusted_image = layer(image)
54+
55+
self.assertTrue(keras.ops.numpy.all(adjusted_image >= 0))
56+
self.assertTrue(keras.ops.numpy.all(adjusted_image <= 1))
57+
58+
def test_random_posterization_value_range_0_to_255(self):
59+
image = keras.random.uniform(shape=(3, 3, 3), minval=0, maxval=255)
60+
61+
layer = layers.RandomPosterization(1, [0, 255])
62+
adjusted_image = layer(image)
63+
64+
self.assertTrue(keras.ops.numpy.all(adjusted_image >= 0))
65+
self.assertTrue(keras.ops.numpy.all(adjusted_image <= 255))
66+
67+
def test_random_posterization_randomness(self):
68+
image = keras.random.uniform(shape=(3, 3, 3), minval=0, maxval=1)
69+
70+
layer = layers.RandomPosterization(1, [0, 255])
71+
adjusted_images = layer(image)
72+
73+
self.assertNotAllClose(adjusted_images, image)
74+
75+
def test_tf_data_compatibility(self):
76+
data_format = backend.config.image_data_format()
77+
if data_format == "channels_last":
78+
input_data = np.random.random((2, 8, 8, 3))
79+
else:
80+
input_data = np.random.random((2, 3, 8, 8))
81+
layer = layers.RandomPosterization(1, [0, 255])
82+
83+
ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer)
84+
for output in ds.take(1):
85+
output.numpy()

0 commit comments

Comments
 (0)