Skip to content

Commit e233825

Browse files
Modified compute_output_shape function to handle broadcasting behavior in layers.Rescaling (#21351)
* Modified compute_output_shape function to handle broadcasting behaviour in layers.Rescaling * Fixed indent and removed tensorflow dependency * Add test case for broadcast output shape * Fix Rescaling layer to handle broadcasting and both data formats correctly
1 parent 79685b6 commit e233825

File tree

2 files changed

+82
-2
lines changed

2 files changed

+82
-2
lines changed

keras/src/layers/preprocessing/rescaling.py

Lines changed: 65 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import numpy as np
2+
13
from keras.src import backend
24
from keras.src.api_export import keras_export
35
from keras.src.layers.preprocessing.tf_data_layer import TFDataLayer
@@ -27,8 +29,16 @@ class Rescaling(TFDataLayer):
2729
(independently of which backend you're using).
2830
2931
Args:
30-
scale: Float, the scale to apply to the inputs.
31-
offset: Float, the offset to apply to the inputs.
32+
scale: Float, int, list, tuple or np.ndarray.
33+
The scale to apply to the inputs.
34+
If scalar, the same scale will be applied to
35+
all features or channels of input. If a list, tuple or
36+
1D array, the scaling is applied per channel.
37+
offset: Float, int, list/tuple or numpy ndarray.
38+
The offset to apply to the inputs.
39+
If scalar, the same offset will be applied to
40+
all features or channels of input. If a list, tuple or
41+
1D array, the offset is applied per channel.
3242
**kwargs: Base layer keyword arguments, such as `name` and `dtype`.
3343
"""
3444

@@ -46,13 +56,66 @@ def call(self, inputs):
4656
if (
4757
len(scale_shape) > 0
4858
and backend.image_data_format() == "channels_first"
59+
and len(inputs.shape) > 2
4960
):
5061
scale = self.backend.numpy.reshape(
5162
scale, scale_shape + (1,) * (3 - len(scale_shape))
5263
)
5364
return self.backend.cast(inputs, dtype) * scale + offset
5465

5566
def compute_output_shape(self, input_shape):
    """Compute the output shape, accounting for channel broadcasting.

    A per-channel `scale` or `offset` (list, tuple or 1D array) whose
    length differs from the input's channel dimension replaces that
    dimension in the returned shape. Scalars and length-1 values
    broadcast against any channel count, so they never change the shape.

    Args:
        input_shape: Shape of the input, as a tuple or list of ints
            (or `None` for unknown dimensions).

    Returns:
        Tuple describing the output shape.

    Raises:
        ValueError: If `scale` and `offset` both call for a new channel
            dimension but disagree on its size.
    """
    input_shape = tuple(input_shape)

    # `call` only moves `scale` onto axis 1 for `channels_first` inputs of
    # rank > 2; lower-rank inputs broadcast against the last axis. Using
    # -1 in that case also avoids an IndexError on rank-1 shapes.
    if (
        backend.image_data_format() == "channels_first"
        and len(input_shape) > 2
    ):
        channels_axis = 1
    else:
        channels_axis = -1

    input_channels = input_shape[channels_axis]

    if input_channels is None:
        # Unknown channel dimension: nothing can be inferred statically.
        return input_shape

    def _channel_len(value):
        # Number of per-channel entries in `value`: 1 for scalars (and
        # 0-d arrays), `None` when the value is of an unrecognized kind
        # or a multi-dimensional array.
        if isinstance(value, (list, tuple)):
            return len(value)
        if isinstance(value, np.ndarray) and value.ndim <= 1:
            return value.size
        if isinstance(value, (int, float)):
            return 1
        return None

    broadcast_len = None
    for param_len in (_channel_len(self.scale), _channel_len(self.offset)):
        # Unknown lengths are ignored. A length of 1 broadcasts against
        # any channel count and a length equal to the input channels is
        # already aligned, so neither changes the output shape.
        if param_len is None or param_len == 1:
            continue
        if param_len == input_channels:
            continue
        if broadcast_len is not None and param_len != broadcast_len:
            raise ValueError(
                "Inconsistent `scale` and `offset` lengths "
                "for broadcasting."
                f" Received: `scale` = {self.scale}, "
                f"`offset` = {self.offset}. "
                "Ensure both `scale` and `offset` are either scalar "
                "or list, tuples, arrays of the same length."
            )
        broadcast_len = param_len

    if broadcast_len:
        output_shape = list(input_shape)
        output_shape[channels_axis] = broadcast_len
        return tuple(output_shape)

    return input_shape
57120

58121
def get_config(self):

keras/src/layers/preprocessing/rescaling_test.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,3 +101,20 @@ def test_numpy_args(self):
101101
expected_num_losses=0,
102102
supports_masking=True,
103103
)
104+
105+
@pytest.mark.requires_trainable_backend
def test_rescaling_broadcast_output_shape(self):
    """Check that two-element `scale`/`offset` lists widen a `(2, 1)`
    input to a `(2, 2)` output shape."""
    layer_kwargs = {
        "scale": [1.0, 1.0],
        "offset": [0.0, 0.0],
    }
    # Expectations forwarded verbatim to `run_layer_test`.
    expectations = {
        "expected_output_shape": (2, 2),
        "expected_num_trainable_weights": 0,
        "expected_num_non_trainable_weights": 0,
        "expected_num_seed_generators": 0,
        "expected_num_losses": 0,
        "supports_masking": True,
    }
    self.run_layer_test(
        layers.Rescaling,
        init_kwargs=layer_kwargs,
        input_shape=(2, 1),
        **expectations,
    )

0 commit comments

Comments
 (0)