Skip to content

Commit bc45bcc

Browse files
roebelroebel
authored and committed
Proposed fix for issue #21519: Reshape layer does not handle -1 shape info dynamically.
1 parent ea62750 commit bc45bcc

File tree

1 file changed

+27
-11
lines changed

1 file changed

+27
-11
lines changed

keras/src/layers/reshaping/reshape.py

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from keras.src.backend.common.keras_tensor import KerasTensor
44
from keras.src.layers.layer import Layer
55
from keras.src.ops import operation_utils
6-
6+
import math
77

88
@keras_export("keras.layers.Reshape")
99
class Reshape(Layer):
@@ -37,7 +37,19 @@ class Reshape(Layer):
3737

3838
def __init__(self, target_shape, **kwargs):
    """Store the target shape and cache values derived from it.

    Args:
        target_shape: Iterable of output dimensions, not including the
            batch dimension. At most one entry may be `-1`.
        **kwargs: Forwarded unchanged to the parent constructor.

    Raises:
        ValueError: If `target_shape` contains more than one `-1`.
    """
    super().__init__(**kwargs)
    shape = tuple(target_shape)
    minus_one_count = shape.count(-1)
    # A single `-1` can be inferred from the other sizes; several cannot.
    if minus_one_count > 1:
        raise ValueError(
            "The `target_shape` argument must not contain more than one "
            "`-1` value. Received: target_shape={}".format(shape)
        )
    self.target_shape = shape
    # Cached for `call`: whether a `-1` entry is present, and the product
    # of every entry that is not `-1`.
    self.need_explicit_shape_for_batch_size_None = minus_one_count == 1
    fixed_dims = [dim for dim in shape if dim != -1]
    self.new_size_no_minus_one = math.prod(fixed_dims)
4153

4254
def compute_output_shape(self, input_shape):
4355
return (
@@ -53,19 +65,23 @@ def compute_output_spec(self, inputs):
5365
shape=output_shape, dtype=inputs.dtype, sparse=inputs.sparse
5466
)
5567

56-
def build(self, input_shape):
57-
sample_output_shape = operation_utils.compute_reshape_output_shape(
58-
input_shape[1:], self.target_shape, "target_shape"
59-
)
60-
self._resolved_target_shape = tuple(
61-
-1 if d is None else d for d in sample_output_shape
62-
)
63-
6468
def call(self, inputs):
    """Reshape `inputs` to `(batch_size,) + target_shape`.

    When `target_shape` holds a `-1`, the static batch size is `None`, and
    every other input dimension is statically known, the `-1` entry is
    replaced by the concrete size it stands for before reshaping.
    Otherwise `target_shape` is passed through unchanged.

    Args:
        inputs: Tensor whose first axis is the batch axis.

    Returns:
        The result of `ops.reshape`, with the batch size taken from
        `ops.shape(inputs)[0]`.

    Raises:
        ValueError: If the `-1` entry has to be resolved here but the
            number of non-batch input elements is not a whole multiple of
            the product of the other target dimensions.
    """
    target_shape = self.target_shape
    resolve_minus_one = (
        self.need_explicit_shape_for_batch_size_None
        and inputs.shape[0] is None
    )
    if resolve_minus_one:
        input_nonbatch_shape = tuple(inputs.shape[1:])
        # The static size is only computable when no non-batch dimension
        # is unknown; otherwise leave the `-1` for the backend to resolve.
        if None not in input_nonbatch_shape:
            input_nonbatch_size = math.prod(input_nonbatch_shape)
            known_size = self.new_size_no_minus_one
            # Fail early with a clear message instead of flooring the
            # division (wrong shape) or dividing by zero.
            if known_size == 0 or input_nonbatch_size % known_size != 0:
                raise ValueError(
                    "The total size of the tensor must be unchanged. "
                    f"Received: input_shape={tuple(inputs.shape)}, "
                    f"target_shape={self.target_shape}"
                )
            inferred_dim = input_nonbatch_size // known_size
            target_shape = tuple(
                inferred_dim if d == -1 else d for d in self.target_shape
            )

    return ops.reshape(inputs, (ops.shape(inputs)[0],) + target_shape)
6883

84+
6985
def get_config(self):
7086
config = {"target_shape": self.target_shape}
7187
base_config = super().get_config()

0 commit comments

Comments
 (0)