1
+ import numpy as np
2
+
1
3
from keras .src import backend
2
4
from keras .src .api_export import keras_export
3
5
from keras .src .layers .preprocessing .tf_data_layer import TFDataLayer
@@ -27,8 +29,16 @@ class Rescaling(TFDataLayer):
27
29
(independently of which backend you're using).
28
30
29
31
Args:
30
- scale: Float, the scale to apply to the inputs.
31
- offset: Float, the offset to apply to the inputs.
32
+ scale: Float, int, list, tuple or np.ndarray.
33
+ The scale to apply to the inputs.
34
+ If scalar, the same scale will be applied to
35
+ all features or channels of input. If a list, tuple or
36
+ 1D array, the scaling is applied per channel.
37
+ offset: Float, int, list/tuple or numpy ndarray.
38
+ The offset to apply to the inputs.
39
+ If scalar, the same scale will be applied to
40
+ all features or channels of input. If a list, tuple or
41
+ 1D array, the scaling is applied per channel.
32
42
**kwargs: Base layer keyword arguments, such as `name` and `dtype`.
33
43
"""
34
44
@@ -46,13 +56,66 @@ def call(self, inputs):
46
56
if (
47
57
len (scale_shape ) > 0
48
58
and backend .image_data_format () == "channels_first"
59
+ and len (inputs .shape ) > 2
49
60
):
50
61
scale = self .backend .numpy .reshape (
51
62
scale , scale_shape + (1 ,) * (3 - len (scale_shape ))
52
63
)
53
64
return self .backend .cast (inputs , dtype ) * scale + offset
54
65
55
66
def compute_output_shape (self , input_shape ):
67
+ input_shape = tuple (input_shape )
68
+
69
+ if backend .image_data_format () == "channels_last" :
70
+ channels_axis = - 1
71
+ else :
72
+ channels_axis = 1
73
+
74
+ input_channels = input_shape [channels_axis ]
75
+
76
+ if input_channels is None :
77
+ return input_shape
78
+
79
+ scale_len = None
80
+ offset_len = None
81
+
82
+ if isinstance (self .scale , (list , tuple )):
83
+ scale_len = len (self .scale )
84
+ elif isinstance (self .scale , np .ndarray ) and self .scale .ndim == 1 :
85
+ scale_len = self .scale .size
86
+ elif isinstance (self .scale , (int , float )):
87
+ scale_len = 1
88
+
89
+ if isinstance (self .offset , (list , tuple )):
90
+ offset_len = len (self .offset )
91
+ elif isinstance (self .offset , np .ndarray ) and self .offset .ndim == 1 :
92
+ offset_len = self .offset .size
93
+ elif isinstance (self .offset , (int , float )):
94
+ offset_len = 1
95
+
96
+ if scale_len == 1 and offset_len == 1 :
97
+ return input_shape
98
+
99
+ broadcast_len = None
100
+ if scale_len is not None and scale_len != input_channels :
101
+ broadcast_len = scale_len
102
+ if offset_len is not None and offset_len != input_channels :
103
+ if broadcast_len is not None and offset_len != broadcast_len :
104
+ raise ValueError (
105
+ "Inconsistent `scale` and `offset` lengths "
106
+ f"for broadcasting."
107
+ f" Received: `scale` = { self .scale } ,"
108
+ f"`offset` = { self .offset } . "
109
+ f"Ensure both `scale` and `offset` are either scalar "
110
+ f"or list, tuples, arrays of the same length."
111
+ )
112
+ broadcast_len = offset_len
113
+
114
+ if broadcast_len :
115
+ output_shape = list (input_shape )
116
+ output_shape [channels_axis ] = broadcast_len
117
+ return tuple (output_shape )
118
+
56
119
return input_shape
57
120
58
121
def get_config (self ):
0 commit comments