Skip to content

Commit e8771ad

Browse files
committed
Revert "Modified compute_output_shape function to handle broadcasting behavior in layers.Rescaling (#21351)"
This reverts commit e233825.
1 parent 0589a1c commit e8771ad

File tree

2 files changed

+2
-82
lines changed

2 files changed

+2
-82
lines changed

keras/src/layers/preprocessing/rescaling.py

Lines changed: 2 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
import numpy as np
2-
31
from keras.src import backend
42
from keras.src.api_export import keras_export
53
from keras.src.layers.preprocessing.tf_data_layer import TFDataLayer
@@ -29,16 +27,8 @@ class Rescaling(TFDataLayer):
2927
(independently of which backend you're using).
3028
3129
Args:
32-
scale: Float, int, list, tuple or np.ndarray.
33-
The scale to apply to the inputs.
34-
If scalar, the same scale will be applied to
35-
all features or channels of input. If a list, tuple or
36-
1D array, the scaling is applied per channel.
37-
offset: Float, int, list/tuple or numpy ndarray.
38-
The offset to apply to the inputs.
39-
If scalar, the same scale will be applied to
40-
all features or channels of input. If a list, tuple or
41-
1D array, the scaling is applied per channel.
30+
scale: Float, the scale to apply to the inputs.
31+
offset: Float, the offset to apply to the inputs.
4232
**kwargs: Base layer keyword arguments, such as `name` and `dtype`.
4333
"""
4434

@@ -56,66 +46,13 @@ def call(self, inputs):
5646
if (
5747
len(scale_shape) > 0
5848
and backend.image_data_format() == "channels_first"
59-
and len(inputs.shape) > 2
6049
):
6150
scale = self.backend.numpy.reshape(
6251
scale, scale_shape + (1,) * (3 - len(scale_shape))
6352
)
6453
return self.backend.cast(inputs, dtype) * scale + offset
6554

6655
def compute_output_shape(self, input_shape):
67-
input_shape = tuple(input_shape)
68-
69-
if backend.image_data_format() == "channels_last":
70-
channels_axis = -1
71-
else:
72-
channels_axis = 1
73-
74-
input_channels = input_shape[channels_axis]
75-
76-
if input_channels is None:
77-
return input_shape
78-
79-
scale_len = None
80-
offset_len = None
81-
82-
if isinstance(self.scale, (list, tuple)):
83-
scale_len = len(self.scale)
84-
elif isinstance(self.scale, np.ndarray) and self.scale.ndim == 1:
85-
scale_len = self.scale.size
86-
elif isinstance(self.scale, (int, float)):
87-
scale_len = 1
88-
89-
if isinstance(self.offset, (list, tuple)):
90-
offset_len = len(self.offset)
91-
elif isinstance(self.offset, np.ndarray) and self.offset.ndim == 1:
92-
offset_len = self.offset.size
93-
elif isinstance(self.offset, (int, float)):
94-
offset_len = 1
95-
96-
if scale_len == 1 and offset_len == 1:
97-
return input_shape
98-
99-
broadcast_len = None
100-
if scale_len is not None and scale_len != input_channels:
101-
broadcast_len = scale_len
102-
if offset_len is not None and offset_len != input_channels:
103-
if broadcast_len is not None and offset_len != broadcast_len:
104-
raise ValueError(
105-
"Inconsistent `scale` and `offset` lengths "
106-
f"for broadcasting."
107-
f" Received: `scale` = {self.scale},"
108-
f"`offset` = {self.offset}. "
109-
f"Ensure both `scale` and `offset` are either scalar "
110-
f"or list, tuples, arrays of the same length."
111-
)
112-
broadcast_len = offset_len
113-
114-
if broadcast_len:
115-
output_shape = list(input_shape)
116-
output_shape[channels_axis] = broadcast_len
117-
return tuple(output_shape)
118-
11956
return input_shape
12057

12158
def get_config(self):

keras/src/layers/preprocessing/rescaling_test.py

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -117,20 +117,3 @@ def test_numpy_args(self):
117117
expected_num_losses=0,
118118
supports_masking=True,
119119
)
120-
121-
@pytest.mark.requires_trainable_backend
122-
def test_rescaling_broadcast_output_shape(self):
123-
self.run_layer_test(
124-
layers.Rescaling,
125-
init_kwargs={
126-
"scale": [1.0, 1.0],
127-
"offset": [0.0, 0.0],
128-
},
129-
input_shape=(2, 1),
130-
expected_output_shape=(2, 2),
131-
expected_num_trainable_weights=0,
132-
expected_num_non_trainable_weights=0,
133-
expected_num_seed_generators=0,
134-
expected_num_losses=0,
135-
supports_masking=True,
136-
)

0 commit comments

Comments
 (0)