1
- import numpy as np
2
-
3
1
from keras .src import backend
4
2
from keras .src .api_export import keras_export
5
3
from keras .src .layers .preprocessing .tf_data_layer import TFDataLayer
@@ -29,16 +27,8 @@ class Rescaling(TFDataLayer):
29
27
(independently of which backend you're using).
30
28
31
29
Args:
32
- scale: Float, int, list, tuple or np.ndarray.
33
- The scale to apply to the inputs.
34
- If scalar, the same scale will be applied to
35
- all features or channels of input. If a list, tuple or
36
- 1D array, the scaling is applied per channel.
37
- offset: Float, int, list/tuple or numpy ndarray.
38
- The offset to apply to the inputs.
39
- If scalar, the same scale will be applied to
40
- all features or channels of input. If a list, tuple or
41
- 1D array, the scaling is applied per channel.
30
+ scale: Float, the scale to apply to the inputs.
31
+ offset: Float, the offset to apply to the inputs.
42
32
**kwargs: Base layer keyword arguments, such as `name` and `dtype`.
43
33
"""
44
34
@@ -56,66 +46,13 @@ def call(self, inputs):
56
46
if (
57
47
len (scale_shape ) > 0
58
48
and backend .image_data_format () == "channels_first"
59
- and len (inputs .shape ) > 2
60
49
):
61
50
scale = self .backend .numpy .reshape (
62
51
scale , scale_shape + (1 ,) * (3 - len (scale_shape ))
63
52
)
64
53
return self .backend .cast (inputs , dtype ) * scale + offset
65
54
66
55
def compute_output_shape (self , input_shape ):
67
- input_shape = tuple (input_shape )
68
-
69
- if backend .image_data_format () == "channels_last" :
70
- channels_axis = - 1
71
- else :
72
- channels_axis = 1
73
-
74
- input_channels = input_shape [channels_axis ]
75
-
76
- if input_channels is None :
77
- return input_shape
78
-
79
- scale_len = None
80
- offset_len = None
81
-
82
- if isinstance (self .scale , (list , tuple )):
83
- scale_len = len (self .scale )
84
- elif isinstance (self .scale , np .ndarray ) and self .scale .ndim == 1 :
85
- scale_len = self .scale .size
86
- elif isinstance (self .scale , (int , float )):
87
- scale_len = 1
88
-
89
- if isinstance (self .offset , (list , tuple )):
90
- offset_len = len (self .offset )
91
- elif isinstance (self .offset , np .ndarray ) and self .offset .ndim == 1 :
92
- offset_len = self .offset .size
93
- elif isinstance (self .offset , (int , float )):
94
- offset_len = 1
95
-
96
- if scale_len == 1 and offset_len == 1 :
97
- return input_shape
98
-
99
- broadcast_len = None
100
- if scale_len is not None and scale_len != input_channels :
101
- broadcast_len = scale_len
102
- if offset_len is not None and offset_len != input_channels :
103
- if broadcast_len is not None and offset_len != broadcast_len :
104
- raise ValueError (
105
- "Inconsistent `scale` and `offset` lengths "
106
- f"for broadcasting."
107
- f" Received: `scale` = { self .scale } ,"
108
- f"`offset` = { self .offset } . "
109
- f"Ensure both `scale` and `offset` are either scalar "
110
- f"or list, tuples, arrays of the same length."
111
- )
112
- broadcast_len = offset_len
113
-
114
- if broadcast_len :
115
- output_shape = list (input_shape )
116
- output_shape [channels_axis ] = broadcast_len
117
- return tuple (output_shape )
118
-
119
56
return input_shape
120
57
121
58
def get_config (self ):
0 commit comments