-
Notifications
You must be signed in to change notification settings - Fork 19.6k
Proposed fix for issue #21519: Reshape layer does not handle -1 shape… #21568
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 3 commits
Commits
Show all changes
15 commits
Select commit
Hold shift + click to select a range
bc45bcc
Proposed fix for issue #21519: Reshape layer does not handle -1 shape…
roebel a502050
Removed unused part of the code. The code in the fix should deal excl…
roebel 49e6646
applied changes according to pre-commit hook
roebel d6ebdce
Added reshape_test test case that fails with original implementation …
roebel 05483ff
Added asserts for expected result.
roebel 223ce95
Fixed test name and added a further test with custom Model and Conv1D…
roebel 7225504
applied changes according to pre-commit hook
roebel 1145fb8
Fixed test to use a custom model and not a custom layer as indicated …
roebel d4c8e9d
Implemented suggested changes:
roebel efa8cb4
Implemented suggested changes:
roebel 8c3d239
Remove unused variable.
roebel 8c3c16b
Fixed line lengths in doc string.
roebel b0076a8
Marked test which uses fit method to require a trainable backend.
roebel c63392e
Docs:
roebel fa108c6
Fixed line length.
roebel File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,5 @@ | ||
import math | ||
|
||
from keras.src import ops | ||
from keras.src.api_export import keras_export | ||
from keras.src.backend.common.keras_tensor import KerasTensor | ||
|
@@ -37,7 +39,21 @@ class Reshape(Layer): | |
|
||
def __init__(self, target_shape, **kwargs): | ||
super().__init__(**kwargs) | ||
self.target_shape = tuple(target_shape) | ||
target_shape = tuple(target_shape) | ||
# test validity of target_shape | ||
if target_shape.count(-1) > 1: | ||
raise ValueError( | ||
"The `target_shape` argument must not contain more than one " | ||
"`-1` value. Received: target_shape={}".format(target_shape) | ||
) | ||
self.target_shape = target_shape | ||
# precalculate all values that might be required | ||
self.need_explicit_shape_for_batch_size_None = ( | ||
target_shape.count(-1) == 1 | ||
) | ||
self.new_size_no_minus_one = math.prod( | ||
d for d in target_shape if d != -1 | ||
) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Because you removed the |
||
|
||
def compute_output_shape(self, input_shape): | ||
return ( | ||
|
@@ -53,18 +69,22 @@ def compute_output_spec(self, inputs): | |
shape=output_shape, dtype=inputs.dtype, sparse=inputs.sparse | ||
) | ||
|
||
def build(self, input_shape): | ||
sample_output_shape = operation_utils.compute_reshape_output_shape( | ||
input_shape[1:], self.target_shape, "target_shape" | ||
) | ||
self._resolved_target_shape = tuple( | ||
-1 if d is None else d for d in sample_output_shape | ||
) | ||
|
||
def call(self, inputs): | ||
return ops.reshape( | ||
inputs, (ops.shape(inputs)[0],) + self._resolved_target_shape | ||
) | ||
target_shape = self.target_shape | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it's easier to just transfer the code from def call(self, inputs):
potentially_resolved_target_shape = (
operation_utils.compute_reshape_output_shape(
tuple(inputs.shape)[1:], self.target_shape, "target_shape"
)
)
potentially_resolved_target_shape = tuple(
-1 if d is None else d for d in potentially_resolved_target_shape
)
return ops.reshape(
inputs, (ops.shape(inputs)[0],) + potentially_resolved_target_shape
)
|
||
if self.need_explicit_shape_for_batch_size_None and ( | ||
inputs.shape[0] is None | ||
): | ||
input_nonbatch_shape = tuple(inputs.shape[1:]) | ||
if input_nonbatch_shape.count(None) == 0: | ||
inp_nonbatch_size = math.prod(inputs.shape[1:]) | ||
target_shape = tuple( | ||
d | ||
if d != -1 | ||
else (inp_nonbatch_size // self.new_size_no_minus_one) | ||
for d in self.target_shape | ||
) | ||
|
||
return ops.reshape(inputs, (ops.shape(inputs)[0],) + target_shape) | ||
|
||
def get_config(self): | ||
config = {"target_shape": self.target_shape} | ||
|
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please use an f-string for this.