Skip to content

Commit 723b4bd

Browse files
author
roebel
committed
Fixed test name and added a further test with custom Model and Conv1D layer.
1 parent 05483ff commit 723b4bd

File tree

1 file changed

+19
-1
lines changed

1 file changed

+19
-1
lines changed

keras/src/layers/reshaping/reshape_test.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import pytest
22
from absl.testing import parameterized
33

4+
from keras.src import Model
45
from keras.src import ops
56
from keras.src import backend
67
from keras.src import layers
@@ -101,7 +102,7 @@ def test_reshape_with_dynamic_batch_size_and_minus_one(self):
101102
reshaped = backend.compute_output_spec(layer.__call__, input)
102103
self.assertEqual(reshaped.shape, (None, 3, 8))
103104

104-
def test_reshape_with_varying_static_batch_size_and_minus_one(self):
105+
def test_reshape_layer_with_varying_input_size_and_minus_one(self):
105106
input = KerasTensor((None, 6, 4))
106107
layer = layers.Reshape((-1, 8))
107108
layer.build(input.shape)
@@ -110,6 +111,23 @@ def test_reshape_with_varying_static_batch_size_and_minus_one(self):
110111
res = layer(ops.ones((1, 10, 4), dtype="float32"))
111112
self.assertEqual(res.shape, (1, 5, 8))
112113

114+
def test_custom_reshape_model_with_varying_input_size_and_minus_one(self):
115+
class MM(layers.Layer):
116+
def __init__(self):
117+
super().__init__()
118+
self.conv = layers.Conv1D(4, 3, padding="same")
119+
self.reshape = layers.Reshape((-1, 8))
120+
121+
def call(self, inputs):
122+
x = self.conv(inputs)
123+
return self.reshape(x)
124+
125+
m = MM()
126+
res = m(ops.ones((1, 6, 2), dtype="float32"))
127+
self.assertEqual(res.shape, (1, 3, 8))
128+
res = m(ops.ones((1, 10, 2), dtype="float32"))
129+
self.assertEqual(res.shape, (1, 5, 8))
130+
113131
def test_reshape_with_dynamic_dim_and_minus_one(self):
114132
input = KerasTensor((4, 6, None, 3))
115133
layer = layers.Reshape((-1, 3))

0 commit comments

Comments
 (0)