Skip to content

Commit 49e6646

Browse files
roebelroebel
authored andcommitted
applied changes according to pre-commit hook
1 parent a502050 commit 49e6646

File tree

1 file changed

+16
-8
lines changed

1 file changed

+16
-8
lines changed

keras/src/layers/reshaping/reshape.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1+
import math
2+
13
from keras.src import ops
24
from keras.src.api_export import keras_export
35
from keras.src.backend.common.keras_tensor import KerasTensor
46
from keras.src.layers.layer import Layer
57
from keras.src.ops import operation_utils
6-
import math
8+
79

810
@keras_export("keras.layers.Reshape")
911
class Reshape(Layer):
@@ -46,7 +48,9 @@ def __init__(self, target_shape, **kwargs):
4648
)
4749
self.target_shape = target_shape
4850
# precalculate all values that might be required
49-
self.need_explicit_shape_for_batch_size_None = (target_shape.count(-1) == 1)
51+
self.need_explicit_shape_for_batch_size_None = (
52+
target_shape.count(-1) == 1
53+
)
5054
self.new_size_no_minus_one = math.prod(
5155
d for d in target_shape if d != -1
5256
)
@@ -67,16 +71,20 @@ def compute_output_spec(self, inputs):
6771

6872
def call(self, inputs):
6973
target_shape = self.target_shape
70-
if self.need_explicit_shape_for_batch_size_None and (inputs.shape[0] is None):
74+
if self.need_explicit_shape_for_batch_size_None and (
75+
inputs.shape[0] is None
76+
):
7177
input_nonbatch_shape = tuple(inputs.shape[1:])
7278
if input_nonbatch_shape.count(None) == 0:
7379
inp_nonbatch_size = math.prod(inputs.shape[1:])
74-
target_shape = tuple(d if d != -1 else (inp_nonbatch_size // self.new_size_no_minus_one) for d in self.target_shape)
75-
76-
return ops.reshape(
77-
inputs, (ops.shape(inputs)[0],) + target_shape
78-
)
80+
target_shape = tuple(
81+
d
82+
if d != -1
83+
else (inp_nonbatch_size // self.new_size_no_minus_one)
84+
for d in self.target_shape
85+
)
7986

87+
return ops.reshape(inputs, (ops.shape(inputs)[0],) + target_shape)
8088

8189
def get_config(self):
8290
config = {"target_shape": self.target_shape}

0 commit comments

Comments
 (0)