Skip to content

Commit d8afc05

Browse files
authored
Fix Randomhue (#20652)
* Small fix in random hue * use self.backend for seed
1 parent 4c05e0c commit d8afc05

File tree

1 file changed

+5
-24
lines changed

1 file changed

+5
-24
lines changed

keras/src/layers/preprocessing/image_preprocessing/random_hue.py

Lines changed: 5 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2,26 +2,6 @@
22
from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501
33
BaseImagePreprocessingLayer,
44
)
5-
from keras.src.random import SeedGenerator
6-
7-
8-
def transform_value_range(images, original_range, target_range):
9-
if (
10-
original_range[0] == target_range[0]
11-
and original_range[1] == target_range[1]
12-
):
13-
return images
14-
15-
original_min_value, original_max_value = original_range
16-
target_min_value, target_max_value = target_range
17-
18-
# images in the [0, 1] scale
19-
images = (images - original_min_value) / (
20-
original_max_value - original_min_value
21-
)
22-
23-
scale_factor = target_max_value - target_min_value
24-
return (images * scale_factor) + target_min_value
255

266

277
@keras_export("keras.layers.RandomHue")
@@ -70,7 +50,7 @@ def __init__(
7050
self._set_factor(factor)
7151
self.value_range = value_range
7252
self.seed = seed
73-
self.generator = SeedGenerator(seed)
53+
self.generator = self.backend.random.SeedGenerator(seed)
7454

7555
def get_random_transformation(self, data, training=True, seed=None):
7656
if isinstance(data, dict):
@@ -107,7 +87,8 @@ def get_random_transformation(self, data, training=True, seed=None):
10787
return {"factor": invert * factor * 0.5}
10888

10989
def transform_images(self, images, transformation=None, training=True):
110-
images = transform_value_range(images, self.value_range, (0, 1))
90+
images = self.backend.cast(images, self.compute_dtype)
91+
images = self._transform_value_range(images, self.value_range, (0, 1))
11192
adjust_factors = transformation["factor"]
11293
adjust_factors = self.backend.cast(adjust_factors, images.dtype)
11394
adjust_factors = self.backend.numpy.expand_dims(adjust_factors, -1)
@@ -144,8 +125,8 @@ def transform_images(self, images, transformation=None, training=True):
144125
)
145126

146127
images = self.backend.numpy.clip(images, 0, 1)
147-
images = transform_value_range(images, (0, 1), self.value_range)
148-
128+
images = self._transform_value_range(images, (0, 1), self.value_range)
129+
images = self.backend.cast(images, self.compute_dtype)
149130
return images
150131

151132
def transform_labels(self, labels, transformation, training=True):

0 commit comments

Comments
 (0)