Skip to content

Commit 3c9fee7

Browse files
authored
Add random_shear processing layer (#20702)
* Add random_shear processing layer * Update method name * Fix failed test case * Fix failed test case * Fix failed test case
1 parent 6ce93a4 commit 3c9fee7

File tree

5 files changed

+348
-0
lines changed

5 files changed

+348
-0
lines changed

keras/api/_tf_keras/keras/layers/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,9 @@
188188
from keras.src.layers.preprocessing.image_preprocessing.random_sharpness import (
189189
RandomSharpness,
190190
)
191+
from keras.src.layers.preprocessing.image_preprocessing.random_shear import (
192+
RandomShear,
193+
)
191194
from keras.src.layers.preprocessing.image_preprocessing.random_translation import (
192195
RandomTranslation,
193196
)

keras/api/layers/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,9 @@
188188
from keras.src.layers.preprocessing.image_preprocessing.random_sharpness import (
189189
RandomSharpness,
190190
)
191+
from keras.src.layers.preprocessing.image_preprocessing.random_shear import (
192+
RandomShear,
193+
)
191194
from keras.src.layers.preprocessing.image_preprocessing.random_translation import (
192195
RandomTranslation,
193196
)

keras/src/layers/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,9 @@
132132
from keras.src.layers.preprocessing.image_preprocessing.random_sharpness import (
133133
RandomSharpness,
134134
)
135+
from keras.src.layers.preprocessing.image_preprocessing.random_shear import (
136+
RandomShear,
137+
)
135138
from keras.src.layers.preprocessing.image_preprocessing.random_translation import (
136139
RandomTranslation,
137140
)
Lines changed: 263 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,263 @@
1+
from keras.src.api_export import keras_export
2+
from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501
3+
BaseImagePreprocessingLayer,
4+
)
5+
from keras.src.random.seed_generator import SeedGenerator
6+
7+
8+
@keras_export("keras.layers.RandomShear")
class RandomShear(BaseImagePreprocessingLayer):
    """A preprocessing layer that randomly applies shear transformations to
    images.

    This layer shears the input images along the x-axis and/or y-axis by a
    randomly selected factor within the specified range. The shear
    transformation is applied to each image independently in a batch, and
    the direction of the shear is flipped at random for each image. Empty
    regions created during the transformation are filled according to the
    `fill_mode` and `fill_value` parameters.

    Args:
        x_factor: A float or a tuple of two floats. For each augmented
            image, a value is sampled from the provided range. If a float
            is passed, the range is interpreted as `(0, x_factor)`. Values
            represent a percentage of the image to shear over. For example,
            0.3 shears pixels up to 30% of the way across the image. All
            provided values should be in the range `[0, 1]`.
        y_factor: A float or a tuple of two floats. For each augmented
            image, a value is sampled from the provided range. If a float
            is passed, the range is interpreted as `(0, y_factor)`. Values
            represent a percentage of the image to shear over. For example,
            0.3 shears pixels up to 30% of the way across the image. All
            provided values should be in the range `[0, 1]`.
        interpolation: Interpolation mode. Supported values: `"nearest"`,
            `"bilinear"`. Defaults to `"bilinear"`.
        fill_mode: Points outside the boundaries of the input are filled
            according to the given mode. Available methods are `"constant"`,
            `"nearest"`, `"wrap"` and `"reflect"`. Defaults to `"reflect"`.
            - `"reflect"`: `(d c b a | a b c d | d c b a)`
                The input is extended by reflecting about the edge of the
                last pixel.
            - `"constant"`: `(k k k k | a b c d | k k k k)`
                The input is extended by filling all values beyond the edge
                with the same constant value `k` specified by `fill_value`.
            - `"wrap"`: `(a b c d | a b c d | a b c d)`
                The input is extended by wrapping around to the opposite edge.
            - `"nearest"`: `(a a a a | a b c d | d d d d)`
                The input is extended by the nearest pixel.
            Note that when using torch backend, `"reflect"` is redirected to
            `"mirror"` `(c d c b | a b c d | c b a b)` because torch does
            not support `"reflect"`.
            Note that torch backend does not support `"wrap"`.
        fill_value: A float representing the value to be filled outside the
            boundaries when `fill_mode="constant"`.
        data_format: Image data format, forwarded to the base layer and to
            the affine transform (e.g. `"channels_last"` or
            `"channels_first"`). Defaults to `None`, in which case the base
            layer decides (presumably the global Keras image data format).
        seed: Integer. Used to create a random seed.
    """

    _USE_BASE_FACTOR = False
    _FACTOR_BOUNDS = (0, 1)
    _FACTOR_VALIDATION_ERROR = (
        "The `factor` argument should be a number (or a list of two numbers) "
        "in the range [0, 1.0]. "
    )
    _SUPPORTED_FILL_MODE = ("reflect", "wrap", "constant", "nearest")
    _SUPPORTED_INTERPOLATION = ("nearest", "bilinear")

    def __init__(
        self,
        x_factor=0.0,
        y_factor=0.0,
        interpolation="bilinear",
        fill_mode="reflect",
        fill_value=0.0,
        data_format=None,
        seed=None,
        **kwargs,
    ):
        super().__init__(data_format=data_format, **kwargs)
        # Both factors are normalized to a `(lower, upper)` pair.
        self.x_factor = self._set_factor_with_name(x_factor, "x_factor")
        self.y_factor = self._set_factor_with_name(y_factor, "y_factor")

        if fill_mode not in self._SUPPORTED_FILL_MODE:
            raise NotImplementedError(
                f"Unknown `fill_mode` {fill_mode}. Expected of one "
                f"{self._SUPPORTED_FILL_MODE}."
            )
        if interpolation not in self._SUPPORTED_INTERPOLATION:
            raise NotImplementedError(
                f"Unknown `interpolation` {interpolation}. Expected of one "
                f"{self._SUPPORTED_INTERPOLATION}."
            )

        self.fill_mode = fill_mode
        self.fill_value = fill_value
        self.interpolation = interpolation
        self.seed = seed
        self.generator = SeedGenerator(seed)
        self.supports_jit = False

    def _set_factor_with_name(self, factor, factor_name):
        """Validate `factor` and normalize it to a `(lower, upper)` pair.

        Args:
            factor: A number, or a tuple/list of exactly two numbers, each
                in the range `[0, 1]`.
            factor_name: Name of the argument, used in error messages.

        Returns:
            A `(lower, upper)` tuple. A two-element sequence is returned
            sorted; a single number `f` becomes `(0, f)`.

        Raises:
            ValueError: If `factor` is neither a number nor a sequence of
                two numbers, or if any value lies outside `[0, 1]`.
        """
        if isinstance(factor, (tuple, list)):
            if len(factor) != 2:
                raise ValueError(
                    self._FACTOR_VALIDATION_ERROR
                    + f"Received: {factor_name}={factor}"
                )
            self._check_factor_range(factor[0])
            self._check_factor_range(factor[1])
            lower, upper = sorted(factor)
        elif isinstance(factor, (int, float)):
            self._check_factor_range(factor)
            # A scalar means the range `(0, factor)`, as documented. The
            # stored pair must itself pass the tuple validation above,
            # because `get_config()` serializes it and `from_config()`
            # feeds it back into `__init__`; a negative lower bound would
            # make that round trip raise. The shear direction is still
            # randomized by the sign flip in `get_random_transformation`.
            lower, upper = 0, factor
        else:
            raise ValueError(
                self._FACTOR_VALIDATION_ERROR
                + f"Received: {factor_name}={factor}"
            )
        return lower, upper

    def _check_factor_range(self, input_number):
        """Raise `ValueError` unless `input_number` lies within `[0, 1]`."""
        if input_number > 1.0 or input_number < 0.0:
            raise ValueError(
                self._FACTOR_VALIDATION_ERROR
                + f"Received: input_number={input_number}"
            )

    def get_random_transformation(self, data, training=True, seed=None):
        """Sample the per-image shear factors for one call.

        Args:
            data: Either the images themselves or a dict holding them under
                the `"images"` key.
            training: When falsy, no transformation is sampled.
            seed: Optional seed forwarded to the backend random ops. When
                `None`, the layer's own seed generator is used.

        Returns:
            `None` when not training, otherwise a dict with a
            `"shear_factor"` entry of shape `(batch_size, 2)` holding the
            signed `(x, y)` shear factors of each image.
        """
        if not training:
            return None

        if isinstance(data, dict):
            images = data["images"]
        else:
            images = data

        # A rank-3 input is a single unbatched image.
        images_shape = self.backend.shape(images)
        if len(images_shape) == 3:
            batch_size = 1
        else:
            batch_size = images_shape[0]

        if seed is None:
            seed = self._get_seed_generator(self.backend._backend)
        # NOTE(review): the three draws below share `seed`. With the layer's
        # seed generator that is presumably fine; with a caller-supplied
        # fixed seed the draws may be correlated -- confirm.

        # Per-image random sign (+1 or -1) deciding the shear direction.
        invert = self.backend.random.uniform(
            minval=0,
            maxval=1,
            shape=[batch_size, 1],
            seed=seed,
            dtype=self.compute_dtype,
        )
        invert = self.backend.numpy.where(
            invert > 0.5,
            -self.backend.numpy.ones_like(invert),
            self.backend.numpy.ones_like(invert),
        )

        shear_y = self.backend.random.uniform(
            minval=self.y_factor[0],
            maxval=self.y_factor[1],
            shape=[batch_size, 1],
            seed=seed,
            dtype=self.compute_dtype,
        )
        shear_x = self.backend.random.uniform(
            minval=self.x_factor[0],
            maxval=self.x_factor[1],
            shape=[batch_size, 1],
            seed=seed,
            dtype=self.compute_dtype,
        )
        # `invert` has shape (batch, 1), so the same sign is broadcast over
        # both the x and the y factor of an image.
        shear_factor = (
            self.backend.cast(
                self.backend.numpy.concatenate([shear_x, shear_y], axis=1),
                dtype=self.compute_dtype,
            )
            * invert
        )
        return {"shear_factor": shear_factor}

    def transform_images(self, images, transformation, training=True):
        """Cast `images` to the compute dtype and shear them when training."""
        images = self.backend.cast(images, self.compute_dtype)
        if training:
            return self._shear_inputs(images, transformation)
        return images

    def _shear_inputs(self, inputs, transformation):
        """Apply the sampled shear to `inputs` via an affine transform.

        Returns `inputs` untouched when `transformation` is `None`. A rank-3
        (unbatched) input is temporarily given a batch axis and restored to
        rank 3 afterwards.
        """
        if transformation is None:
            return inputs

        inputs_shape = self.backend.shape(inputs)
        unbatched = len(inputs_shape) == 3
        if unbatched:
            inputs = self.backend.numpy.expand_dims(inputs, axis=0)

        shear_factor = transformation["shear_factor"]
        outputs = self.backend.image.affine_transform(
            inputs,
            transform=self._get_shear_matrix(shear_factor),
            interpolation=self.interpolation,
            fill_mode=self.fill_mode,
            fill_value=self.fill_value,
            data_format=self.data_format,
        )

        if unbatched:
            outputs = self.backend.numpy.squeeze(outputs, axis=0)
        return outputs

    def _get_shear_matrix(self, shear_factors):
        """Build the flattened shear transform for each image.

        Args:
            shear_factors: Tensor whose first column holds the x shear and
                whose second column holds the y shear of each image.

        Returns:
            A tensor with eight values per image,
            `[1, s_x, 0, s_y, 1, 0, 0, 0]`.
        """
        num_shear_factors = self.backend.shape(shear_factors)[0]

        # The shear matrix looks like:
        # [[1   s_x  0]
        #  [s_y  1   0]
        #  [0    0   1]]
        # Only the first eight entries (row-major) are stacked; the final 1
        # is left out -- presumably implied by `affine_transform`.

        return self.backend.numpy.stack(
            [
                self.backend.numpy.ones((num_shear_factors,)),
                shear_factors[:, 0],
                self.backend.numpy.zeros((num_shear_factors,)),
                shear_factors[:, 1],
                self.backend.numpy.ones((num_shear_factors,)),
                self.backend.numpy.zeros((num_shear_factors,)),
                self.backend.numpy.zeros((num_shear_factors,)),
                self.backend.numpy.zeros((num_shear_factors,)),
            ],
            axis=1,
        )

    def transform_labels(self, labels, transformation, training=True):
        """Return `labels` unchanged."""
        return labels

    def transform_bounding_boxes(
        self,
        bounding_boxes,
        transformation,
        training=True,
    ):
        """Not supported: always raises `NotImplementedError`."""
        raise NotImplementedError

    def transform_segmentation_masks(
        self, segmentation_masks, transformation, training=True
    ):
        """Shear segmentation masks exactly like images."""
        return self.transform_images(
            segmentation_masks, transformation, training=training
        )

    def get_config(self):
        """Return the layer configuration merged over the base config.

        `x_factor` and `y_factor` are emitted as their normalized
        `(lower, upper)` pairs.
        """
        base_config = super().get_config()
        config = {
            "x_factor": self.x_factor,
            "y_factor": self.y_factor,
            "fill_mode": self.fill_mode,
            "interpolation": self.interpolation,
            "seed": self.seed,
            "fill_value": self.fill_value,
            "data_format": self.data_format,
        }
        return {**base_config, **config}

    def compute_output_shape(self, input_shape):
        """Return `input_shape` as the output shape."""
        return input_shape
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
import numpy as np
2+
import pytest
3+
from tensorflow import data as tf_data
4+
5+
import keras
6+
from keras.src import backend
7+
from keras.src import layers
8+
from keras.src import testing
9+
10+
11+
class RandomShearTest(testing.TestCase):
    """Tests for the `RandomShear` preprocessing layer."""

    @pytest.mark.requires_trainable_backend
    def test_layer(self):
        """Run the generic layer test with tuple factors."""
        self.run_layer_test(
            layers.RandomShear,
            init_kwargs={
                "x_factor": (0.5, 1),
                "y_factor": (0.5, 1),
                "interpolation": "bilinear",
                "fill_mode": "reflect",
                "data_format": "channels_last",
                "seed": 1,
            },
            input_shape=(8, 3, 4, 3),
            supports_masking=False,
            expected_output_shape=(8, 3, 4, 3),
        )

    def test_random_shear_inference(self):
        """With `training=False` the layer must return its input values."""
        seed = 3481
        layer = layers.RandomShear(1, 1)
        np.random.seed(seed)
        inputs = np.random.randint(0, 255, size=(224, 224, 3))
        output = layer(inputs, training=False)
        self.assertAllClose(inputs, output)

    def test_invalid_factor_raises(self):
        """Out-of-range or malformed factors are rejected."""
        with self.assertRaises(ValueError):
            layers.RandomShear(x_factor=1.5)
        with self.assertRaises(ValueError):
            layers.RandomShear(y_factor=-0.1)
        with self.assertRaises(ValueError):
            layers.RandomShear(x_factor=(0.1, 0.2, 0.3))

    def test_shear_pixel_level(self):
        """Shearing must change the value of the centre pixel."""
        image = np.zeros((1, 5, 5, 3))
        image[0, 1:4, 1:4, :] = 1.0
        image[0, 2, 2, :] = [0.0, 1.0, 0.0]
        image = keras.ops.convert_to_tensor(image, dtype="float32")

        # Read the reference pixel while the tensor is still channels_last;
        # after a transpose to channels_first, `[0, 2, 2, :]` would select a
        # row of width 5 instead of the three channel values.
        original_pixel = image[0, 2, 2, :]

        data_format = backend.config.image_data_format()
        if data_format == "channels_first":
            image = keras.ops.transpose(image, (0, 3, 1, 2))

        shear_layer = layers.RandomShear(
            x_factor=(0.2, 0.3),
            y_factor=(0.2, 0.3),
            interpolation="bilinear",
            fill_mode="constant",
            fill_value=0.0,
            seed=42,
            data_format=data_format,
        )

        sheared_image = shear_layer(image)

        # Bring the output back to channels_last so the indexing matches.
        if data_format == "channels_first":
            sheared_image = keras.ops.transpose(sheared_image, (0, 2, 3, 1))

        sheared_pixel = sheared_image[0, 2, 2, :]
        self.assertNotAllClose(original_pixel, sheared_pixel)

    def test_tf_data_compatibility(self):
        """The layer must be usable inside a `tf.data` pipeline."""
        data_format = backend.config.image_data_format()
        if data_format == "channels_last":
            input_data = np.random.random((2, 8, 8, 3))
        else:
            input_data = np.random.random((2, 3, 8, 8))
        layer = layers.RandomShear(1, 1)

        ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer)
        for output in ds.take(1):
            output.numpy()

0 commit comments

Comments
 (0)