3
3
from keras .src .backend .common .keras_tensor import KerasTensor
4
4
from keras .src .layers .layer import Layer
5
5
from keras .src .ops import operation_utils
6
-
6
+ import math
7
7
8
8
@keras_export ("keras.layers.Reshape" )
9
9
class Reshape (Layer ):
@@ -37,7 +37,19 @@ class Reshape(Layer):
37
37
38
38
def __init__ (self , target_shape , ** kwargs ):
39
39
super ().__init__ (** kwargs )
40
- self .target_shape = tuple (target_shape )
40
+ target_shape = tuple (target_shape )
41
+ # test validity of target_shape
42
+ if target_shape .count (- 1 ) > 1 :
43
+ raise ValueError (
44
+ "The `target_shape` argument must not contain more than one "
45
+ "`-1` value. Received: target_shape={}" .format (target_shape )
46
+ )
47
+ self .target_shape = target_shape
48
+ # precalculate all values that might be required
49
+ self .need_explicit_shape_for_batch_size_None = (target_shape .count (- 1 ) == 1 )
50
+ self .new_size_no_minus_one = math .prod (
51
+ d for d in target_shape if d != - 1
52
+ )
41
53
42
54
def compute_output_shape (self , input_shape ):
43
55
return (
@@ -53,19 +65,23 @@ def compute_output_spec(self, inputs):
53
65
shape = output_shape , dtype = inputs .dtype , sparse = inputs .sparse
54
66
)
55
67
56
- def build (self , input_shape ):
57
- sample_output_shape = operation_utils .compute_reshape_output_shape (
58
- input_shape [1 :], self .target_shape , "target_shape"
59
- )
60
- self ._resolved_target_shape = tuple (
61
- - 1 if d is None else d for d in sample_output_shape
62
- )
63
-
64
68
def call (self , inputs ):
69
+ target_shape = self .target_shape
70
+ if self .need_explicit_shape_for_batch_size_None and (inputs .shape [0 ] is None ):
71
+ input_nonbatch_shape = tuple (inputs .shape [1 :])
72
+ if input_nonbatch_shape .count (None ) == 0 :
73
+ # If the input shape is fully defined, we can compute the desired target_shape
74
+ if True :
75
+ inp_nonbatch_size = math .prod (inputs .shape [1 :])
76
+ else :
77
+ inp_nonbatch_size = ops .prod (ops .shape (inputs )[1 :])
78
+ target_shape = tuple (d if d != - 1 else (inp_nonbatch_size // self .new_size_no_minus_one ) for d in self .target_shape )
79
+
65
80
return ops .reshape (
66
- inputs , (ops .shape (inputs )[0 ],) + self . _resolved_target_shape
81
+ inputs , (ops .shape (inputs )[0 ],) + target_shape
67
82
)
68
83
84
+
69
85
def get_config (self ):
70
86
config = {"target_shape" : self .target_shape }
71
87
base_config = super ().get_config ()
0 commit comments