|
2 | 2 | from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501
|
3 | 3 | BaseImagePreprocessingLayer,
|
4 | 4 | )
|
5 |
| -from keras.src.random import SeedGenerator |
6 |
| - |
7 |
| - |
8 |
| -def transform_value_range(images, original_range, target_range): |
9 |
| - if ( |
10 |
| - original_range[0] == target_range[0] |
11 |
| - and original_range[1] == target_range[1] |
12 |
| - ): |
13 |
| - return images |
14 |
| - |
15 |
| - original_min_value, original_max_value = original_range |
16 |
| - target_min_value, target_max_value = target_range |
17 |
| - |
18 |
| - # images in the [0, 1] scale |
19 |
| - images = (images - original_min_value) / ( |
20 |
| - original_max_value - original_min_value |
21 |
| - ) |
22 |
| - |
23 |
| - scale_factor = target_max_value - target_min_value |
24 |
| - return (images * scale_factor) + target_min_value |
25 | 5 |
|
26 | 6 |
|
27 | 7 | @keras_export("keras.layers.RandomHue")
|
@@ -70,7 +50,7 @@ def __init__(
|
70 | 50 | self._set_factor(factor)
|
71 | 51 | self.value_range = value_range
|
72 | 52 | self.seed = seed
|
73 |
| - self.generator = SeedGenerator(seed) |
| 53 | + self.generator = self.backend.random.SeedGenerator(seed) |
74 | 54 |
|
75 | 55 | def get_random_transformation(self, data, training=True, seed=None):
|
76 | 56 | if isinstance(data, dict):
|
@@ -107,7 +87,8 @@ def get_random_transformation(self, data, training=True, seed=None):
|
107 | 87 | return {"factor": invert * factor * 0.5}
|
108 | 88 |
|
109 | 89 | def transform_images(self, images, transformation=None, training=True):
|
110 |
| - images = transform_value_range(images, self.value_range, (0, 1)) |
| 90 | + images = self.backend.cast(images, self.compute_dtype) |
| 91 | + images = self._transform_value_range(images, self.value_range, (0, 1)) |
111 | 92 | adjust_factors = transformation["factor"]
|
112 | 93 | adjust_factors = self.backend.cast(adjust_factors, images.dtype)
|
113 | 94 | adjust_factors = self.backend.numpy.expand_dims(adjust_factors, -1)
|
@@ -144,8 +125,8 @@ def transform_images(self, images, transformation=None, training=True):
|
144 | 125 | )
|
145 | 126 |
|
146 | 127 | images = self.backend.numpy.clip(images, 0, 1)
|
147 |
| - images = transform_value_range(images, (0, 1), self.value_range) |
148 |
| - |
| 128 | + images = self._transform_value_range(images, (0, 1), self.value_range) |
| 129 | + images = self.backend.cast(images, self.compute_dtype) |
149 | 130 | return images
|
150 | 131 |
|
151 | 132 | def transform_labels(self, labels, transformation, training=True):
|
|
0 commit comments