1
1
import pytest
2
2
from absl .testing import parameterized
3
3
4
+ from keras .src import Model
4
5
from keras .src import ops
5
6
from keras .src import backend
6
7
from keras .src import layers
@@ -101,7 +102,7 @@ def test_reshape_with_dynamic_batch_size_and_minus_one(self):
101
102
reshaped = backend .compute_output_spec (layer .__call__ , input )
102
103
self .assertEqual (reshaped .shape , (None , 3 , 8 ))
103
104
104
- def test_reshape_with_varying_static_batch_size_and_minus_one (self ):
105
+ def test_reshape_layer_with_varying_input_size_and_minus_one (self ):
105
106
input = KerasTensor ((None , 6 , 4 ))
106
107
layer = layers .Reshape ((- 1 , 8 ))
107
108
layer .build (input .shape )
@@ -110,6 +111,23 @@ def test_reshape_with_varying_static_batch_size_and_minus_one(self):
110
111
res = layer (ops .ones ((1 , 10 , 4 ), dtype = "float32" ))
111
112
self .assertEqual (res .shape , (1 , 5 , 8 ))
112
113
114
+ def test_custom_reshape_model_with_varying_input_size_and_minus_one (self ):
115
+ class MM (layers .Layer ):
116
+ def __init__ (self ):
117
+ super ().__init__ ()
118
+ self .conv = layers .Conv1D (4 , 3 , padding = "same" )
119
+ self .reshape = layers .Reshape ((- 1 , 8 ))
120
+
121
+ def call (self , inputs ):
122
+ x = self .conv (inputs )
123
+ return self .reshape (x )
124
+
125
+ m = MM ()
126
+ res = m (ops .ones ((1 , 6 , 2 ), dtype = "float32" ))
127
+ self .assertEqual (res .shape , (1 , 3 , 8 ))
128
+ res = m (ops .ones ((1 , 10 , 2 ), dtype = "float32" ))
129
+ self .assertEqual (res .shape , (1 , 5 , 8 ))
130
+
113
131
def test_reshape_with_dynamic_dim_and_minus_one (self ):
114
132
input = KerasTensor ((4 , 6 , None , 3 ))
115
133
layer = layers .Reshape ((- 1 , 3 ))
0 commit comments