Skip to content
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 32 additions & 12 deletions keras/src/layers/reshaping/reshape.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import math

from keras.src import ops
from keras.src.api_export import keras_export
from keras.src.backend.common.keras_tensor import KerasTensor
Expand Down Expand Up @@ -37,7 +39,21 @@ class Reshape(Layer):

def __init__(self, target_shape, **kwargs):
super().__init__(**kwargs)
self.target_shape = tuple(target_shape)
target_shape = tuple(target_shape)
# test validity of target_shape
if target_shape.count(-1) > 1:
raise ValueError(
"The `target_shape` argument must not contain more than one "
"`-1` value. Received: target_shape={}".format(target_shape)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use an f-string for this.

)
self.target_shape = target_shape
# precalculate all values that might be required
self.need_explicit_shape_for_batch_size_None = (
target_shape.count(-1) == 1
)
self.new_size_no_minus_one = math.prod(
d for d in target_shape if d != -1
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because you removed the build method, add self.built = True at the end of __init__.


def compute_output_shape(self, input_shape):
return (
Expand All @@ -53,18 +69,22 @@ def compute_output_spec(self, inputs):
shape=output_shape, dtype=inputs.dtype, sparse=inputs.sparse
)

def build(self, input_shape):
sample_output_shape = operation_utils.compute_reshape_output_shape(
input_shape[1:], self.target_shape, "target_shape"
)
self._resolved_target_shape = tuple(
-1 if d is None else d for d in sample_output_shape
)

def call(self, inputs):
return ops.reshape(
inputs, (ops.shape(inputs)[0],) + self._resolved_target_shape
)
target_shape = self.target_shape
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's easier to just transfer the code from build here:

    def call(self, inputs):
        potentially_resolved_target_shape = (
            operation_utils.compute_reshape_output_shape(
                tuple(inputs.shape)[1:], self.target_shape, "target_shape"
            )
        )
        potentially_resolved_target_shape = tuple(
            -1 if d is None else d for d in potentially_resolved_target_shape
        )
        return ops.reshape(
            inputs, (ops.shape(inputs)[0],) + potentially_resolved_target_shape
        )
  • you don't have to reimplement the computation of the missing dimension
  • you don't have to deal with errors if the number of values is not divisible by self.new_size_no_minus_one (right now, that check is missing)
  • you don't need self.need_explicit_shape_for_batch_size_None and self.new_size_no_minus_one

if self.need_explicit_shape_for_batch_size_None and (
inputs.shape[0] is None
):
input_nonbatch_shape = tuple(inputs.shape[1:])
if input_nonbatch_shape.count(None) == 0:
inp_nonbatch_size = math.prod(inputs.shape[1:])
target_shape = tuple(
d
if d != -1
else (inp_nonbatch_size // self.new_size_no_minus_one)
for d in self.target_shape
)

return ops.reshape(inputs, (ops.shape(inputs)[0],) + target_shape)

def get_config(self):
config = {"target_shape": self.target_shape}
Expand Down