Skip to content

Proposed fix for issue #21519: Reshape layer does not handle -1 shape… #21568

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
44 changes: 32 additions & 12 deletions keras/src/layers/reshaping/reshape.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import math

from keras.src import ops
from keras.src.api_export import keras_export
from keras.src.backend.common.keras_tensor import KerasTensor
Expand Down Expand Up @@ -37,7 +39,21 @@ class Reshape(Layer):

def __init__(self, target_shape, **kwargs):
super().__init__(**kwargs)
self.target_shape = tuple(target_shape)
target_shape = tuple(target_shape)
# test validity of target_shape
if target_shape.count(-1) > 1:
raise ValueError(
"The `target_shape` argument must not contain more than one "
"`-1` value. Received: target_shape={}".format(target_shape)
)
self.target_shape = target_shape
# precalculate all values that might be required
self.need_explicit_shape_for_batch_size_None = (
target_shape.count(-1) == 1
)
self.new_size_no_minus_one = math.prod(
d for d in target_shape if d != -1
)

def compute_output_shape(self, input_shape):
return (
Expand All @@ -53,18 +69,22 @@ def compute_output_spec(self, inputs):
shape=output_shape, dtype=inputs.dtype, sparse=inputs.sparse
)

def build(self, input_shape):
sample_output_shape = operation_utils.compute_reshape_output_shape(
input_shape[1:], self.target_shape, "target_shape"
)
self._resolved_target_shape = tuple(
-1 if d is None else d for d in sample_output_shape
)

def call(self, inputs):
return ops.reshape(
inputs, (ops.shape(inputs)[0],) + self._resolved_target_shape
)
target_shape = self.target_shape
if self.need_explicit_shape_for_batch_size_None and (
inputs.shape[0] is None
):
input_nonbatch_shape = tuple(inputs.shape[1:])
if input_nonbatch_shape.count(None) == 0:
inp_nonbatch_size = math.prod(inputs.shape[1:])
target_shape = tuple(
d
if d != -1
else (inp_nonbatch_size // self.new_size_no_minus_one)
for d in self.target_shape
)

return ops.reshape(inputs, (ops.shape(inputs)[0],) + target_shape)

def get_config(self):
config = {"target_shape": self.target_shape}
Expand Down
28 changes: 28 additions & 0 deletions keras/src/layers/reshaping/reshape_test.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import pytest
from absl.testing import parameterized

from keras.src import Model
from keras.src import backend
from keras.src import layers
from keras.src import ops
from keras.src import testing
from keras.src.backend.common.keras_tensor import KerasTensor

Expand Down Expand Up @@ -100,6 +102,32 @@ def test_reshape_with_dynamic_batch_size_and_minus_one(self):
reshaped = backend.compute_output_spec(layer.__call__, input)
self.assertEqual(reshaped.shape, (None, 3, 8))

def test_reshape_layer_with_varying_input_size_and_minus_one(self):
input = KerasTensor((None, 6, 4))
layer = layers.Reshape((-1, 8))
layer.build(input.shape)
res = layer(ops.ones((1, 6, 4), dtype="float32"))
self.assertEqual(res.shape, (1, 3, 8))
res = layer(ops.ones((1, 10, 4), dtype="float32"))
self.assertEqual(res.shape, (1, 5, 8))

def test_custom_reshape_model_with_varying_input_size_and_minus_one(self):
class MM(Model):
def __init__(self):
super().__init__()
self.conv = layers.Conv1D(4, 3, padding="same")
self.reshape = layers.Reshape((-1, 8))

def call(self, inputs):
x = self.conv(inputs)
return self.reshape(x)

m = MM()
res = m(ops.ones((1, 6, 2), dtype="float32"))
self.assertEqual(res.shape, (1, 3, 8))
res = m(ops.ones((1, 10, 2), dtype="float32"))
self.assertEqual(res.shape, (1, 5, 8))

def test_reshape_with_dynamic_dim_and_minus_one(self):
input = KerasTensor((4, 6, None, 3))
layer = layers.Reshape((-1, 3))
Expand Down
Loading