Skip to content

Commit 6ce93a4

Browse files
authored
Add random_sharpness processing layer (#20697)
* Add random_sharpness.py * Update random_sharpness * Add test cases * Fix failed test case
1 parent 2b073b6 commit 6ce93a4

File tree

5 files changed

+242
-0
lines changed

5 files changed

+242
-0
lines changed

keras/api/_tf_keras/keras/layers/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,9 @@
185185
from keras.src.layers.preprocessing.image_preprocessing.random_saturation import (
186186
RandomSaturation,
187187
)
188+
from keras.src.layers.preprocessing.image_preprocessing.random_sharpness import (
189+
RandomSharpness,
190+
)
188191
from keras.src.layers.preprocessing.image_preprocessing.random_translation import (
189192
RandomTranslation,
190193
)

keras/api/layers/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,9 @@
185185
from keras.src.layers.preprocessing.image_preprocessing.random_saturation import (
186186
RandomSaturation,
187187
)
188+
from keras.src.layers.preprocessing.image_preprocessing.random_sharpness import (
189+
RandomSharpness,
190+
)
188191
from keras.src.layers.preprocessing.image_preprocessing.random_translation import (
189192
RandomTranslation,
190193
)

keras/src/layers/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,9 @@
129129
from keras.src.layers.preprocessing.image_preprocessing.random_saturation import (
130130
RandomSaturation,
131131
)
132+
from keras.src.layers.preprocessing.image_preprocessing.random_sharpness import (
133+
RandomSharpness,
134+
)
132135
from keras.src.layers.preprocessing.image_preprocessing.random_translation import (
133136
RandomTranslation,
134137
)
Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
from keras.src.api_export import keras_export
2+
from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501
3+
BaseImagePreprocessingLayer,
4+
)
5+
from keras.src.random import SeedGenerator
6+
7+
8+
@keras_export("keras.layers.RandomSharpness")
9+
class RandomSharpness(BaseImagePreprocessingLayer):
10+
"""Randomly performs the sharpness operation on given images.
11+
12+
The sharpness operation first performs a blur, then blends between the
13+
original image and the processed image. This operation adjusts the clarity
14+
of the edges in an image, ranging from blurred to enhanced sharpness.
15+
16+
Args:
17+
factor: A tuple of two floats or a single float.
18+
`factor` controls the extent to which the image sharpness
19+
is impacted. `factor=0.0` results in a fully blurred image,
20+
`factor=0.5` applies no operation (preserving the original image),
21+
and `factor=1.0` enhances the sharpness beyond the original. Values
22+
should be between `0.0` and `1.0`. If a tuple is used, a `factor`
23+
is sampled between the two values for every image augmented.
24+
If a single float is used, a value between `0.0` and the passed
25+
float is sampled. To ensure the value is always the same,
26+
pass a tuple with two identical floats: `(0.5, 0.5)`.
27+
value_range: the range of values the incoming images will have.
28+
Represented as a two-number tuple written `[low, high]`. This is
29+
typically either `[0, 1]` or `[0, 255]` depending on how your
30+
preprocessing pipeline is set up.
31+
seed: Integer. Used to create a random seed.
32+
"""
33+
34+
_USE_BASE_FACTOR = False
35+
_FACTOR_BOUNDS = (0, 1)
36+
37+
_VALUE_RANGE_VALIDATION_ERROR = (
38+
"The `value_range` argument should be a list of two numbers. "
39+
)
40+
41+
def __init__(
42+
self,
43+
factor,
44+
value_range=(0, 255),
45+
data_format=None,
46+
seed=None,
47+
**kwargs,
48+
):
49+
super().__init__(data_format=data_format, **kwargs)
50+
self._set_factor(factor)
51+
self._set_value_range(value_range)
52+
self.seed = seed
53+
self.generator = SeedGenerator(seed)
54+
55+
def _set_value_range(self, value_range):
56+
if not isinstance(value_range, (tuple, list)):
57+
raise ValueError(
58+
self._VALUE_RANGE_VALIDATION_ERROR
59+
+ f"Received: value_range={value_range}"
60+
)
61+
if len(value_range) != 2:
62+
raise ValueError(
63+
self._VALUE_RANGE_VALIDATION_ERROR
64+
+ f"Received: value_range={value_range}"
65+
)
66+
self.value_range = sorted(value_range)
67+
68+
def get_random_transformation(self, data, training=True, seed=None):
69+
if isinstance(data, dict):
70+
images = data["images"]
71+
else:
72+
images = data
73+
images_shape = self.backend.shape(images)
74+
rank = len(images_shape)
75+
if rank == 3:
76+
batch_size = 1
77+
elif rank == 4:
78+
batch_size = images_shape[0]
79+
else:
80+
raise ValueError(
81+
"Expected the input image to be rank 3 or 4. Received: "
82+
f"inputs.shape={images_shape}"
83+
)
84+
85+
if seed is None:
86+
seed = self._get_seed_generator(self.backend._backend)
87+
88+
factor = self.backend.random.uniform(
89+
(batch_size,),
90+
minval=self.factor[0],
91+
maxval=self.factor[1],
92+
seed=seed,
93+
)
94+
return {"factor": factor}
95+
96+
def transform_images(self, images, transformation=None, training=True):
97+
images = self.backend.cast(images, self.compute_dtype)
98+
if training:
99+
if self.data_format == "channels_first":
100+
images = self.backend.numpy.swapaxes(images, -3, -1)
101+
102+
sharpness_factor = self.backend.cast(
103+
transformation["factor"] * 2, dtype=self.compute_dtype
104+
)
105+
sharpness_factor = self.backend.numpy.reshape(
106+
sharpness_factor, (-1, 1, 1, 1)
107+
)
108+
109+
num_channels = self.backend.shape(images)[-1]
110+
111+
a, b = 1.0 / 13.0, 5.0 / 13.0
112+
kernel = self.backend.convert_to_tensor(
113+
[[a, a, a], [a, b, a], [a, a, a]], dtype=self.compute_dtype
114+
)
115+
kernel = self.backend.numpy.reshape(kernel, (3, 3, 1, 1))
116+
kernel = self.backend.numpy.tile(kernel, [1, 1, num_channels, 1])
117+
kernel = self.backend.cast(kernel, self.compute_dtype)
118+
119+
smoothed_image = self.backend.nn.depthwise_conv(
120+
images,
121+
kernel,
122+
strides=1,
123+
padding="same",
124+
data_format="channels_last",
125+
)
126+
127+
smoothed_image = self.backend.cast(
128+
smoothed_image, dtype=self.compute_dtype
129+
)
130+
images = images + (1.0 - sharpness_factor) * (
131+
smoothed_image - images
132+
)
133+
134+
images = self.backend.numpy.clip(
135+
images, self.value_range[0], self.value_range[1]
136+
)
137+
138+
if self.data_format == "channels_first":
139+
images = self.backend.numpy.swapaxes(images, -3, -1)
140+
141+
return images
142+
143+
def transform_labels(self, labels, transformation, training=True):
144+
return labels
145+
146+
def transform_segmentation_masks(
147+
self, segmentation_masks, transformation, training=True
148+
):
149+
return segmentation_masks
150+
151+
def transform_bounding_boxes(
152+
self, bounding_boxes, transformation, training=True
153+
):
154+
return bounding_boxes
155+
156+
def get_config(self):
157+
config = super().get_config()
158+
config.update(
159+
{
160+
"factor": self.factor,
161+
"value_range": self.value_range,
162+
"seed": self.seed,
163+
}
164+
)
165+
return config
166+
167+
def compute_output_shape(self, input_shape):
168+
return input_shape
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
import numpy as np
2+
import pytest
3+
from tensorflow import data as tf_data
4+
5+
import keras
6+
from keras.src import backend
7+
from keras.src import layers
8+
from keras.src import testing
9+
10+
11+
class RandomSharpnessTest(testing.TestCase):
12+
@pytest.mark.requires_trainable_backend
13+
def test_layer(self):
14+
self.run_layer_test(
15+
layers.RandomSharpness,
16+
init_kwargs={
17+
"factor": 0.75,
18+
"seed": 1,
19+
},
20+
input_shape=(8, 3, 4, 3),
21+
supports_masking=False,
22+
expected_output_shape=(8, 3, 4, 3),
23+
)
24+
25+
def test_random_sharpness_value_range(self):
26+
image = keras.random.uniform(shape=(3, 3, 3), minval=0, maxval=1)
27+
28+
layer = layers.RandomSharpness(0.2)
29+
adjusted_image = layer(image)
30+
31+
self.assertTrue(keras.ops.numpy.all(adjusted_image >= 0))
32+
self.assertTrue(keras.ops.numpy.all(adjusted_image <= 1))
33+
34+
def test_random_sharpness_no_op(self):
35+
data_format = backend.config.image_data_format()
36+
if data_format == "channels_last":
37+
inputs = np.random.random((2, 8, 8, 3))
38+
else:
39+
inputs = np.random.random((2, 3, 8, 8))
40+
41+
layer = layers.RandomSharpness((0.5, 0.5))
42+
output = layer(inputs, training=False)
43+
self.assertAllClose(inputs, output, atol=1e-3, rtol=1e-5)
44+
45+
def test_random_sharpness_randomness(self):
46+
image = keras.random.uniform(shape=(3, 3, 3), minval=0, maxval=1)[:5]
47+
48+
layer = layers.RandomSharpness(0.2)
49+
adjusted_images = layer(image)
50+
51+
self.assertNotAllClose(adjusted_images, image)
52+
53+
def test_tf_data_compatibility(self):
54+
data_format = backend.config.image_data_format()
55+
if data_format == "channels_last":
56+
input_data = np.random.random((2, 8, 8, 3))
57+
else:
58+
input_data = np.random.random((2, 3, 8, 8))
59+
layer = layers.RandomSharpness(
60+
factor=0.5, data_format=data_format, seed=1337
61+
)
62+
63+
ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer)
64+
for output in ds.take(1):
65+
output.numpy()

0 commit comments

Comments
 (0)