Skip to content

Commit 2d96838

Browse files
committed
Fix random hue layer
1 parent 5a9b26d commit 2d96838

File tree

1 file changed

+10
-35
lines changed

1 file changed

+10
-35
lines changed

keras/src/layers/preprocessing/image_preprocessing/random_hue.py

Lines changed: 10 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -59,46 +59,18 @@ class RandomHue(BaseImagePreprocessingLayer):
5959
```
6060
"""
6161

62+
_USE_BASE_FACTOR = True
63+
_FACTOR_BOUNDS = (0, 1)
64+
6265
def __init__(
6366
self, factor, value_range, data_format=None, seed=None, **kwargs
6467
):
6568
super().__init__(data_format=data_format, **kwargs)
66-
self.factor = factor
69+
self._set_factor(factor)
6770
self.value_range = value_range
6871
self.seed = seed
6972
self.generator = SeedGenerator(seed)
7073

71-
def parse_factor(
72-
self, min_value=0.0, max_value=1.0, param_name="factor", shape=None
73-
):
74-
factors = self.factor
75-
if isinstance(factors, float) or isinstance(factors, int):
76-
factors = (min_value, factors)
77-
78-
if factors[0] > factors[1]:
79-
raise ValueError(
80-
f"`{param_name}[0] > {param_name}[1]`, "
81-
f"`{param_name}[0]` must be "
82-
f"<= `{param_name}[1]`. Got `{param_name}={factors}`"
83-
)
84-
if (min_value is not None and factors[0] < min_value) or (
85-
max_value is not None and factors[1] > max_value
86-
):
87-
raise ValueError(
88-
f"`{param_name}` should be inside of range "
89-
f"[{min_value}, {max_value}]. Got {param_name}={factors}"
90-
)
91-
92-
if factors[0] == factors[1]:
93-
return self.backend.numpy.ones(shape=shape) * factors[0]
94-
95-
return self.backend.random.uniform(
96-
shape,
97-
seed=self.generator,
98-
minval=factors[0],
99-
maxval=factors[1],
100-
)
101-
10274
def get_random_transformation(self, data, training=True, seed=None):
10375
if isinstance(data, dict):
10476
images = data["images"]
@@ -125,9 +97,12 @@ def get_random_transformation(self, data, training=True, seed=None):
12597
-self.backend.numpy.ones_like(invert),
12698
self.backend.numpy.ones_like(invert),
12799
)
128-
129-
factor = self.parse_factor(shape=(batch_size,))
130-
100+
factor = self.backend.random.uniform(
101+
(batch_size,),
102+
minval=self.factor[0],
103+
maxval=self.factor[1],
104+
seed=seed,
105+
)
131106
return {"factor": invert * factor * 0.5}
132107

133108
def transform_images(self, images, transformation=None, training=True):

0 commit comments

Comments
 (0)