Skip to content

Commit d6ebdce

Browse files
roebelroebel
authored andcommitted
Added reshape_test test case that fails with original implementation and succeeds with fix.
1 parent 49e6646 commit d6ebdce

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

keras/src/layers/reshaping/reshape_test.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import pytest
22
from absl.testing import parameterized
33

4+
from keras.src import ops
45
from keras.src import backend
56
from keras.src import layers
67
from keras.src import testing
@@ -100,6 +101,16 @@ def test_reshape_with_dynamic_batch_size_and_minus_one(self):
100101
reshaped = backend.compute_output_spec(layer.__call__, input)
101102
self.assertEqual(reshaped.shape, (None, 3, 8))
102103

104+
def test_reshape_with_varying_static_batch_size_and_minus_one(self):
105+
input = KerasTensor((None, 6, 4))
106+
layer = layers.Reshape((-1, 8))
107+
layer.build(input.shape)
108+
layer(ops.ones((1, 6, 4), dtype="float32"))
109+
layer(ops.ones((1, 10, 4), dtype="float32"))
110+
reshaped = backend.compute_output_spec(layer.__call__, input)
111+
self.assertEqual(reshaped.shape, (None, 3, 8))
112+
113+
103114
def test_reshape_with_dynamic_dim_and_minus_one(self):
104115
input = KerasTensor((4, 6, None, 3))
105116
layer = layers.Reshape((-1, 3))

0 commit comments

Comments
 (0)