Commit 0d3ba37

Add random_color_jitter processing layer (#20673)
* Add implementations for random_saturation
* change parse_factor method to inner method.
* Add implementations for random_color_jitter
* Fix Randomhue (#20652)
* Small fix in random hue
* use self.backend for seed
* test: add test for class weights (py_dataset adapter) (#20638)
* test: add test for class weights (py_dataset adapter)
* "call _standardize_batch from enqueuer" m
* add more tests, handle pytorch astype issue m
* convert to numpy to ensure consistent handling of operations
* Fix paths for pytest in contribution guide (#20655)
* Add preliminary support of OpenVINO as Keras 3 backend (#19727)
* [POC][OV] Support OpenVINO as Keras 3 backend Signed-off-by: Kazantsev, Roman <[email protected]>
* Mark all unsupported ops from numpy space Signed-off-by: Kazantsev, Roman <[email protected]>
* Mark unsupported ops in core, image, and linalg spaces Signed-off-by: Kazantsev, Roman <[email protected]>
* Mark unsupported ops in math, nn, random, and rnn spaces Signed-off-by: Kazantsev, Roman <[email protected]>
* Fix sorting imports Signed-off-by: Kazantsev, Roman <[email protected]>
* Format imports Signed-off-by: Kazantsev, Roman <[email protected]>
* Fix sorting imports Signed-off-by: Kazantsev, Roman <[email protected]>
* Fix sorting imports Signed-off-by: Kazantsev, Roman <[email protected]>
* Fix inference Signed-off-by: Kazantsev, Roman <[email protected]>
* Remove openvino specific code in common part Signed-off-by: Kazantsev, Roman <[email protected]>
* Fix typo
* Clean-up code Signed-off-by: Kazantsev, Roman <[email protected]>
* Recover imports Signed-off-by: Kazantsev, Roman <[email protected]>
* Sort imports properly Signed-off-by: Kazantsev, Roman <[email protected]>
* Format source code Signed-off-by: Kazantsev, Roman <[email protected]>
* Format the rest of source code Signed-off-by: Kazantsev, Roman <[email protected]>
* Continue format adjustment Signed-off-by: Kazantsev, Roman <[email protected]>
* Add OpenVINO dependency Signed-off-by: Kazantsev, Roman <[email protected]>
* Fix inference using OV backend Signed-off-by: Kazantsev, Roman <[email protected]>
* Support bert_base_en_uncased and mobilenet_v3_small from Keras Hub Signed-off-by: Kazantsev, Roman <[email protected]>
* Remove extra openvino specific code from layer.py Signed-off-by: Kazantsev, Roman <[email protected]>
* Apply code-style formatting Signed-off-by: Kazantsev, Roman <[email protected]>
* Apply code-style formatting Signed-off-by: Kazantsev, Roman <[email protected]>
* Fix remained code-style issue Signed-off-by: Kazantsev, Roman <[email protected]>
* Run tests for OpenVINO backend in GHA Signed-off-by: Kazantsev, Roman <[email protected]>
* Add config file for openvino backend validation Signed-off-by: Kazantsev, Roman <[email protected]>
* Add import test for openvino backend Signed-off-by: Kazantsev, Roman <[email protected]>
* Fix error in import_test.py Signed-off-by: Kazantsev, Roman <[email protected]>
* Add import_test for openvino backend Signed-off-by: Kazantsev, Roman <[email protected]>
* Add openvino specific integration tests in GHA Signed-off-by: Kazantsev, Roman <[email protected]>
* Exclude coverage for OpenVINO Signed-off-by: Kazantsev, Roman <[email protected]>
* remove coverage for openvino backend Signed-off-by: Kazantsev, Roman <[email protected]>
* Try layer tests for openvino backend Signed-off-by: Kazantsev, Roman <[email protected]>
* Run layer tests for openvino backend selectively Signed-off-by: Kazantsev, Roman <[email protected]>
* Mark enabled tests for openvino backend in a different way Signed-off-by: Kazantsev, Roman <[email protected]>
* Update .github/workflows/actions.yml
* Fix import for BackendVariable Signed-off-by: Kazantsev, Roman <[email protected]>
* Fix errors in layer tests for openvino backend Signed-off-by: Kazantsev, Roman <[email protected]>
* Add test for Elu via openvino backend Signed-off-by: Kazantsev, Roman <[email protected]>
* Fix sorted imports Signed-off-by: Kazantsev, Roman <[email protected]>
* Extend testing for attention Signed-off-by: Kazantsev, Roman <[email protected]>
* Update keras/src/layers/attention/attention_test.py
* Switch on activation tests for openvino backend Signed-off-by: Kazantsev, Roman <[email protected]>
* Switch on attention tests for openvino backend Signed-off-by: Kazantsev, Roman <[email protected]>
* Update keras/src/layers/attention/additive_attention_test.py
* Update keras/src/layers/attention/grouped_query_attention_test.py
* Run conv tests for openvino backend Signed-off-by: Kazantsev, Roman <[email protected]>
* Fix convolution in openvino backend Signed-off-by: Kazantsev, Roman <[email protected]>
* Work around constant creation for tuple Signed-off-by: Kazantsev, Roman <[email protected]>
* Work around constant creation in reshape Signed-off-by: Kazantsev, Roman <[email protected]>
* Run depthwise conv tests for openvino backend Signed-off-by: Kazantsev, Roman <[email protected]>
* Fix get_ov_output for other x types Signed-off-by: Kazantsev, Roman <[email protected]>
* Fix elu translation Signed-off-by: Kazantsev, Roman <[email protected]>
* Fix softmax and log_softmax for None axis Signed-off-by: Kazantsev, Roman <[email protected]>
* Run nn tests for openvino backend Signed-off-by: Kazantsev, Roman <[email protected]>
* Fix numpy operations for axis to be None Signed-off-by: Kazantsev, Roman <[email protected]>
* Run operation_test for openvino_backend Signed-off-by: Kazantsev, Roman <[email protected]>
* Switch on math_test for openvino backend Signed-off-by: Kazantsev, Roman <[email protected]>
* Switch on image tests for openvino backend Signed-off-by: Kazantsev, Roman <[email protected]>
* Switch on linalg test for openvino backend Signed-off-by: Kazantsev, Roman <[email protected]>
* Extend OpenVINOKerasTensor with new built-in methods and fix shape op Signed-off-by: Kazantsev, Roman <[email protected]>
* Switch on core tests for openvino backend Signed-off-by: Kazantsev, Roman <[email protected]>
* Use different way of OpenVINO model creation that supports call method Signed-off-by: Kazantsev, Roman <[email protected]>
* Unify integration test for openvino Signed-off-by: Kazantsev, Roman <[email protected]>
* Support new operations abs, mod, etc. Signed-off-by: Kazantsev, Roman <[email protected]>
* Add support for more operations like squeeze, max Signed-off-by: Kazantsev, Roman <[email protected]>
* Try to use excluded test files list Signed-off-by: Kazantsev, Roman <[email protected]>
* Apply formatting for normalization_test.py Signed-off-by: Kazantsev, Roman <[email protected]>
* Correct GHA yml file Signed-off-by: Kazantsev, Roman <[email protected]>
* Test that openvino backend is used Signed-off-by: Kazantsev, Roman <[email protected]>
* Revert testing change in excluded test files list Signed-off-by: Kazantsev, Roman <[email protected]>
* Include testing group Signed-off-by: Kazantsev, Roman <[email protected]>
* Include legacy test group Signed-off-by: Kazantsev, Roman <[email protected]>
* Exclude legacy group of tests Signed-off-by: Kazantsev, Roman <[email protected]>
* Include initializers tests Signed-off-by: Kazantsev, Roman <[email protected]>
* Skip tests for initializers group Signed-off-by: Kazantsev, Roman <[email protected]>
* Remove export test group from ignore Signed-off-by: Kazantsev, Roman <[email protected]>
* Include dtype_policies test group Signed-off-by: Kazantsev, Roman <[email protected]>
* Reduce ignored tests Signed-off-by: Kazantsev, Roman <[email protected]>
* Fix ops.cast Signed-off-by: Kazantsev, Roman <[email protected]>
* Add decorator for custom_gradient Signed-off-by: Kazantsev, Roman <[email protected]>
* Shorten line in custom_gradient Signed-off-by: Kazantsev, Roman <[email protected]>
* Ignore dtype_policy_map test Signed-off-by: Kazantsev, Roman <[email protected]>
* Include callback tests Signed-off-by: Kazantsev, Roman <[email protected]>
* Switch on backend tests Signed-off-by: Kazantsev, Roman <[email protected]>
* Exclude failing tests Signed-off-by: Kazantsev, Roman <[email protected]>
* Correct paths to excluded tests Signed-off-by: Kazantsev, Roman <[email protected]>
* Switch on some layers tests Signed-off-by: Kazantsev, Roman <[email protected]>
* Remove pytest.mark.openvino_backend Signed-off-by: Kazantsev, Roman <[email protected]>
* Register mark requires_trainable_backend Signed-off-by: Kazantsev, Roman <[email protected]>
* Ignore test files in a different way Signed-off-by: Kazantsev, Roman <[email protected]>
* Try different way to ignore test files Signed-off-by: Kazantsev, Roman <[email protected]>
* Fix GHA yml Signed-off-by: Kazantsev, Roman <[email protected]>
* Support tuple axis for logsumexp Signed-off-by: Kazantsev, Roman <[email protected]>
* Switch on some ops tests Signed-off-by: Kazantsev, Roman <[email protected]>
* Switch on some callbacks tests Signed-off-by: Kazantsev, Roman <[email protected]>
* Add openvino export Signed-off-by: Kazantsev, Roman <[email protected]>
* Update sklearn tests Signed-off-by: Kazantsev, Roman <[email protected]>
* Add a comment to skipp numerical_test Signed-off-by: Kazantsev, Roman <[email protected]>
* Add custom requirements file for OpenVINO Signed-off-by: Kazantsev, Roman <[email protected]>
* Add reqs of openvino installation for api changes check Signed-off-by: Kazantsev, Roman <[email protected]>
* Fix types of Variables and switch on some variables tests Signed-off-by: Kazantsev, Roman <[email protected]>
* Fix nightly code check Signed-off-by: Kazantsev, Roman <[email protected]>
---------
Signed-off-by: Kazantsev, Roman <[email protected]>
* Make sklearn dependency optional (#20657)
* Add a condition to verify training status during image processing (#20650)
* Add a condition to verify training status during image processing
* resolve merge conflict
* fix transform_bounding_boxes logic
* add transform_bounding_boxes test
* Fix recurrent dropout for GRU. (#20656) The simplified implementation, which used the same recurrent dropout masks for all the previous states didn't work and caused the training to not converge with large enough recurrent dropout values. This new implementation is now the same as Keras 2. Note that recurrent dropout requires "implementation 1" to be turned on. Fixes #20276
* Fix example title in probabilistic_metrics.py (#20662)
* Change recurrent dropout implementation for LSTM. (#20663) This change is to make the implementation of recurrent dropout consistent with GRU (changed as of #20656 ) and Keras 2. Also fixed a bug where the GRU fix would break when using CUDNN with a dropout and no recurrent dropout. The solution is to create multiple masks only when needed (implementation == 1). Added coverage for the case when dropout is set and recurrent dropout is not set.
* Never pass enable_xla=False or native_serialization=False in tests (#20664) These are invalid options in the latest version of jax2tf, they will just immediately throw.
* Fix `PyDatasetAdapterTest::test_class_weight` test with Torch on GPU. (#20665) The test was failing because arrays on device and on cpu were compared.
* Fix up torch GPU failing test for mix up (#20666) We need to make sure to use get any tensors places on cpu before using them in the tensorflow backend during preprocessing.
* Add random_color_jitter processing layer
* Add random_color_jitter test
* Update test cases
* Correct failed test case
* Correct failed test case
* Correct failed test case
---------
Signed-off-by: Kazantsev, Roman <[email protected]>
Co-authored-by: IMvision12 <[email protected]>
Co-authored-by: Enrico <[email protected]>
Co-authored-by: Marco <[email protected]>
Co-authored-by: Roman Kazantsev <[email protected]>
Co-authored-by: Matt Watson <[email protected]>
Co-authored-by: hertschuh <[email protected]>
Co-authored-by: Jasmine Dhantule <[email protected]>
1 parent 7c491bd commit 0d3ba37

5 files changed: +341 -0 lines changed


keras/api/_tf_keras/keras/layers/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -155,6 +155,9 @@
 from keras.src.layers.preprocessing.image_preprocessing.random_brightness import (
     RandomBrightness,
 )
+from keras.src.layers.preprocessing.image_preprocessing.random_color_jitter import (
+    RandomColorJitter,
+)
 from keras.src.layers.preprocessing.image_preprocessing.random_contrast import (
     RandomContrast,
 )

keras/api/layers/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -155,6 +155,9 @@
 from keras.src.layers.preprocessing.image_preprocessing.random_brightness import (
     RandomBrightness,
 )
+from keras.src.layers.preprocessing.image_preprocessing.random_color_jitter import (
+    RandomColorJitter,
+)
 from keras.src.layers.preprocessing.image_preprocessing.random_contrast import (
     RandomContrast,
 )

keras/src/layers/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -99,6 +99,9 @@
 from keras.src.layers.preprocessing.image_preprocessing.random_brightness import (
     RandomBrightness,
 )
+from keras.src.layers.preprocessing.image_preprocessing.random_color_jitter import (
+    RandomColorJitter,
+)
 from keras.src.layers.preprocessing.image_preprocessing.random_contrast import (
     RandomContrast,
 )
keras/src/layers/preprocessing/image_preprocessing/random_color_jitter.py

Lines changed: 197 additions & 0 deletions
@@ -0,0 +1,197 @@
import keras.src.layers.preprocessing.image_preprocessing.random_brightness as random_brightness  # noqa: E501
import keras.src.layers.preprocessing.image_preprocessing.random_contrast as random_contrast  # noqa: E501
import keras.src.layers.preprocessing.image_preprocessing.random_hue as random_hue  # noqa: E501
import keras.src.layers.preprocessing.image_preprocessing.random_saturation as random_saturation  # noqa: E501
from keras.src.api_export import keras_export
from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import (  # noqa: E501
    BaseImagePreprocessingLayer,
)
from keras.src.random.seed_generator import SeedGenerator
from keras.src.utils import backend_utils


@keras_export("keras.layers.RandomColorJitter")
class RandomColorJitter(BaseImagePreprocessingLayer):
    """RandomColorJitter class randomly applies brightness, contrast,
    saturation and hue image processing operations sequentially and
    randomly on the input.

    Args:
        value_range: the range of values the incoming images will have.
            Represented as a two number tuple written [low, high].
            This is typically either `[0, 1]` or `[0, 255]` depending
            on how your preprocessing pipeline is set up.
        brightness_factor: Float or a list/tuple of 2 floats between -1.0
            and 1.0. The factor is used to determine the lower bound and
            upper bound of the brightness adjustment. A float value will
            be chosen randomly between the limits. When -1.0 is chosen,
            the output image will be black, and when 1.0 is chosen, the
            image will be fully white. When only one float is provided,
            e.g., 0.2, then -0.2 will be used for lower bound and 0.2 will
            be used for upper bound.
        contrast_factor: a positive float represented as fraction of value,
            or a tuple of size 2 representing lower and upper bound. When
            represented as a single float, lower = upper. The contrast
            factor will be randomly picked between `[1.0 - lower, 1.0 +
            upper]`. For any pixel x in the channel, the output will be
            `(x - mean) * factor + mean` where `mean` is the mean value
            of the channel.
        saturation_factor: A tuple of two floats or a single float. `factor`
            controls the extent to which the image saturation is impacted.
            `factor=0.5` makes this layer perform a no-op operation.
            `factor=0.0` makes the image fully grayscale. `factor=1.0`
            makes the image fully saturated. Values should be between
            `0.0` and `1.0`. If a tuple is used, a `factor` is sampled
            between the two values for every image augmented. If a single
            float is used, a value between `0.0` and the passed float is
            sampled. To ensure the value is always the same, pass a tuple
            with two identical floats: `(0.5, 0.5)`.
        hue_factor: A single float or a tuple of two floats. `factor`
            controls the extent to which the image hue is impacted.
            `factor=0.0` makes this layer perform a no-op operation,
            while a value of `1.0` performs the most aggressive contrast
            adjustment available. If a tuple is used, a `factor` is
            sampled between the two values for every image augmented.
            If a single float is used, a value between `0.0` and the
            passed float is sampled. In order to ensure the value is
            always the same, please pass a tuple with two identical
            floats: `(0.5, 0.5)`.
        seed: Integer. Used to create a random seed.
    """

    def __init__(
        self,
        value_range=(0, 255),
        brightness_factor=None,
        contrast_factor=None,
        saturation_factor=None,
        hue_factor=None,
        seed=None,
        data_format=None,
        **kwargs,
    ):
        super().__init__(data_format=data_format, **kwargs)
        self.value_range = value_range
        self.brightness_factor = brightness_factor
        self.contrast_factor = contrast_factor
        self.saturation_factor = saturation_factor
        self.hue_factor = hue_factor
        self.seed = seed
        self.generator = SeedGenerator(seed)

        self.random_brightness = None
        self.random_contrast = None
        self.random_saturation = None
        self.random_hue = None

        if self.brightness_factor is not None:
            self.random_brightness = random_brightness.RandomBrightness(
                factor=self.brightness_factor,
                value_range=self.value_range,
                seed=self.seed,
            )

        if self.contrast_factor is not None:
            self.random_contrast = random_contrast.RandomContrast(
                factor=self.contrast_factor,
                value_range=self.value_range,
                seed=self.seed,
            )

        if self.saturation_factor is not None:
            self.random_saturation = random_saturation.RandomSaturation(
                factor=self.saturation_factor,
                value_range=self.value_range,
                seed=self.seed,
            )

        if self.hue_factor is not None:
            self.random_hue = random_hue.RandomHue(
                factor=self.hue_factor,
                value_range=self.value_range,
                seed=self.seed,
            )

    def transform_images(self, images, transformation, training=True):
        if training:
            if backend_utils.in_tf_graph():
                self.backend.set_backend("tensorflow")
            images = self.backend.cast(images, self.compute_dtype)
            if self.brightness_factor is not None:
                if backend_utils.in_tf_graph():
                    self.random_brightness.backend.set_backend("tensorflow")
                transformation = (
                    self.random_brightness.get_random_transformation(
                        images,
                        seed=self._get_seed_generator(self.backend._backend),
                    )
                )
                images = self.random_brightness.transform_images(
                    images, transformation
                )
            if self.contrast_factor is not None:
                if backend_utils.in_tf_graph():
                    self.random_contrast.backend.set_backend("tensorflow")
                transformation = self.random_contrast.get_random_transformation(
                    images, seed=self._get_seed_generator(self.backend._backend)
                )
                transformation["contrast_factor"] = self.backend.cast(
                    transformation["contrast_factor"], dtype=self.compute_dtype
                )
                images = self.random_contrast.transform_images(
                    images, transformation
                )
            if self.saturation_factor is not None:
                if backend_utils.in_tf_graph():
                    self.random_saturation.backend.set_backend("tensorflow")
                transformation = (
                    self.random_saturation.get_random_transformation(
                        images,
                        seed=self._get_seed_generator(self.backend._backend),
                    )
                )
                images = self.random_saturation.transform_images(
                    images, transformation
                )
            if self.hue_factor is not None:
                if backend_utils.in_tf_graph():
                    self.random_hue.backend.set_backend("tensorflow")
                transformation = self.random_hue.get_random_transformation(
                    images, seed=self._get_seed_generator(self.backend._backend)
                )
                images = self.random_hue.transform_images(
                    images, transformation
                )
            images = self.backend.cast(images, self.compute_dtype)
        return images

    def transform_labels(self, labels, transformation, training=True):
        return labels

    def transform_bounding_boxes(
        self,
        bounding_boxes,
        transformation,
        training=True,
    ):
        return bounding_boxes

    def transform_segmentation_masks(
        self, segmentation_masks, transformation, training=True
    ):
        return segmentation_masks

    def compute_output_shape(self, input_shape):
        return input_shape

    def get_config(self):
        config = {
            "value_range": self.value_range,
            "brightness_factor": self.brightness_factor,
            "contrast_factor": self.contrast_factor,
            "saturation_factor": self.saturation_factor,
            "hue_factor": self.hue_factor,
            "seed": self.seed,
        }
        base_config = super().get_config()
        return {**base_config, **config}
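
For context, a minimal usage sketch of the new layer (not part of the diff; the input array, factor values, and variable names below are illustrative, and it assumes a Keras build that includes this commit):

import numpy as np
import keras

# A hypothetical batch of 8 RGB images with values in [0, 255].
images = np.random.randint(0, 256, size=(8, 32, 32, 3)).astype("float32")

# Any factor left as None simply skips that sub-operation.
layer = keras.layers.RandomColorJitter(
    value_range=(0, 255),
    brightness_factor=0.2,
    contrast_factor=0.2,
    saturation_factor=(0.4, 0.6),
    hue_factor=0.1,
    seed=42,
)

augmented = layer(images, training=True)   # brightness, contrast, saturation, hue applied in sequence
unchanged = layer(images, training=False)  # at inference time the images pass through untouched
print(augmented.shape)                     # (8, 32, 32, 3), shape is preserved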
keras/src/layers/preprocessing/image_preprocessing/random_color_jitter_test.py

Lines changed: 135 additions & 0 deletions
@@ -0,0 +1,135 @@
import numpy as np
import pytest
from tensorflow import data as tf_data

from keras.src import backend
from keras.src import layers
from keras.src import testing


class RandomColorJitterTest(testing.TestCase):
    @pytest.mark.requires_trainable_backend
    def test_layer(self):
        self.run_layer_test(
            layers.RandomColorJitter,
            init_kwargs={
                "value_range": (20, 200),
                "brightness_factor": 0.2,
                "contrast_factor": 0.2,
                "saturation_factor": 0.2,
                "hue_factor": 0.2,
                "seed": 1,
            },
            input_shape=(8, 3, 4, 3),
            supports_masking=False,
            expected_output_shape=(8, 3, 4, 3),
        )

    def test_random_color_jitter_inference(self):
        seed = 3481
        layer = layers.RandomColorJitter(
            value_range=(0, 1),
            brightness_factor=0.1,
            contrast_factor=0.2,
            saturation_factor=0.9,
            hue_factor=0.1,
        )

        np.random.seed(seed)
        inputs = np.random.randint(0, 255, size=(224, 224, 3))
        output = layer(inputs, training=False)
        self.assertAllClose(inputs, output)

    def test_brightness_only(self):
        seed = 2390
        np.random.seed(seed)

        data_format = backend.config.image_data_format()
        if data_format == "channels_last":
            inputs = np.random.random((12, 8, 16, 3))
        else:
            inputs = np.random.random((12, 3, 8, 16))

        layer = layers.RandomColorJitter(
            brightness_factor=[0.5, 0.5], seed=seed
        )
        output = backend.convert_to_numpy(layer(inputs))

        layer = layers.RandomBrightness(factor=[0.5, 0.5], seed=seed)
        sub_output = backend.convert_to_numpy(layer(inputs))

        self.assertAllClose(output, sub_output)

    def test_saturation_only(self):
        seed = 2390
        np.random.seed(seed)

        data_format = backend.config.image_data_format()
        if data_format == "channels_last":
            inputs = np.random.random((12, 8, 16, 3))
        else:
            inputs = np.random.random((12, 3, 8, 16))

        layer = layers.RandomColorJitter(
            saturation_factor=[0.5, 0.5], seed=seed
        )
        output = layer(inputs)

        layer = layers.RandomSaturation(factor=[0.5, 0.5], seed=seed)
        sub_output = layer(inputs)

        self.assertAllClose(output, sub_output)

    def test_hue_only(self):
        seed = 2390
        np.random.seed(seed)

        data_format = backend.config.image_data_format()
        if data_format == "channels_last":
            inputs = np.random.random((12, 8, 16, 3))
        else:
            inputs = np.random.random((12, 3, 8, 16))

        layer = layers.RandomColorJitter(hue_factor=[0.5, 0.5], seed=seed)
        output = layer(inputs)

        layer = layers.RandomHue(factor=[0.5, 0.5], seed=seed)
        sub_output = layer(inputs)

        self.assertAllClose(output, sub_output)

    def test_contrast_only(self):
        seed = 2390
        np.random.seed(seed)

        data_format = backend.config.image_data_format()
        if data_format == "channels_last":
            inputs = np.random.random((12, 8, 16, 3))
        else:
            inputs = np.random.random((12, 3, 8, 16))

        layer = layers.RandomColorJitter(contrast_factor=[0.5, 0.5], seed=seed)
        output = layer(inputs)

        layer = layers.RandomContrast(factor=[0.5, 0.5], seed=seed)
        sub_output = layer(inputs)

        self.assertAllClose(output, sub_output)

    def test_tf_data_compatibility(self):
        data_format = backend.config.image_data_format()
        if data_format == "channels_last":
            input_data = np.random.random((2, 8, 8, 3))
        else:
            input_data = np.random.random((2, 3, 8, 8))
        layer = layers.RandomColorJitter(
            value_range=(0, 1),
            brightness_factor=0.1,
            contrast_factor=0.2,
            saturation_factor=0.9,
            hue_factor=0.1,
        )

        ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer)
        for output in ds.take(1):
            output.numpy()
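
For completeness, a small sketch of the tf.data usage that test_tf_data_compatibility exercises (the dataset below is made-up data, and the layer switches its sub-layers to TensorFlow ops when it detects a tf.data graph, per the in_tf_graph checks in the layer above):

import numpy as np
from tensorflow import data as tf_data

import keras

# Hypothetical images with values in [0, 1].
images = np.random.random((16, 8, 8, 3)).astype("float32")

layer = keras.layers.RandomColorJitter(
    value_range=(0, 1),
    brightness_factor=0.1,
    contrast_factor=0.2,
    saturation_factor=0.9,
    hue_factor=0.1,
)

# Mapping the layer over a batched dataset keeps the augmentation in-graph.
ds = tf_data.Dataset.from_tensor_slices(images).batch(4).map(layer)
for batch in ds.take(1):
    batch.numpy()  # materialize one augmented batch of shape (4, 8, 8, 3)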
