|
1 | 1 | import pytest
|
2 | 2 | from absl.testing import parameterized
|
3 | 3 |
|
| 4 | +from keras.src import ops |
4 | 5 | from keras.src import backend
|
5 | 6 | from keras.src import layers
|
6 | 7 | from keras.src import testing
|
@@ -100,6 +101,16 @@ def test_reshape_with_dynamic_batch_size_and_minus_one(self):
|
100 | 101 | reshaped = backend.compute_output_spec(layer.__call__, input)
|
101 | 102 | self.assertEqual(reshaped.shape, (None, 3, 8))
|
102 | 103 |
|
| 104 | + def test_reshape_with_varying_static_batch_size_and_minus_one(self): |
| 105 | + input = KerasTensor((None, 6, 4)) |
| 106 | + layer = layers.Reshape((-1, 8)) |
| 107 | + layer.build(input.shape) |
| 108 | + layer(ops.ones((1, 6, 4), dtype="float32")) |
| 109 | + layer(ops.ones((1, 10, 4), dtype="float32")) |
| 110 | + reshaped = backend.compute_output_spec(layer.__call__, input) |
| 111 | + self.assertEqual(reshaped.shape, (None, 3, 8)) |
| 112 | + |
| 113 | + |
103 | 114 | def test_reshape_with_dynamic_dim_and_minus_one(self):
|
104 | 115 | input = KerasTensor((4, 6, None, 3))
|
105 | 116 | layer = layers.Reshape((-1, 3))
|
|
0 commit comments