1
+ import math
2
+
1
3
from keras .src import ops
2
4
from keras .src .api_export import keras_export
3
5
from keras .src .backend .common .keras_tensor import KerasTensor
4
6
from keras .src .layers .layer import Layer
5
7
from keras .src .ops import operation_utils
6
- import math
8
+
7
9
8
10
@keras_export ("keras.layers.Reshape" )
9
11
class Reshape (Layer ):
@@ -46,7 +48,9 @@ def __init__(self, target_shape, **kwargs):
46
48
)
47
49
self .target_shape = target_shape
48
50
# precalculate all values that might be required
49
- self .need_explicit_shape_for_batch_size_None = (target_shape .count (- 1 ) == 1 )
51
+ self .need_explicit_shape_for_batch_size_None = (
52
+ target_shape .count (- 1 ) == 1
53
+ )
50
54
self .new_size_no_minus_one = math .prod (
51
55
d for d in target_shape if d != - 1
52
56
)
@@ -67,16 +71,20 @@ def compute_output_spec(self, inputs):
67
71
68
72
def call (self , inputs ):
69
73
target_shape = self .target_shape
70
- if self .need_explicit_shape_for_batch_size_None and (inputs .shape [0 ] is None ):
74
+ if self .need_explicit_shape_for_batch_size_None and (
75
+ inputs .shape [0 ] is None
76
+ ):
71
77
input_nonbatch_shape = tuple (inputs .shape [1 :])
72
78
if input_nonbatch_shape .count (None ) == 0 :
73
79
inp_nonbatch_size = math .prod (inputs .shape [1 :])
74
- target_shape = tuple (d if d != - 1 else (inp_nonbatch_size // self .new_size_no_minus_one ) for d in self .target_shape )
75
-
76
- return ops .reshape (
77
- inputs , (ops .shape (inputs )[0 ],) + target_shape
78
- )
80
+ target_shape = tuple (
81
+ d
82
+ if d != - 1
83
+ else (inp_nonbatch_size // self .new_size_no_minus_one )
84
+ for d in self .target_shape
85
+ )
79
86
87
+ return ops .reshape (inputs , (ops .shape (inputs )[0 ],) + target_shape )
80
88
81
89
def get_config (self ):
82
90
config = {"target_shape" : self .target_shape }
0 commit comments