|
11 | 11 | from keras.src import ops
|
12 | 12 | from keras.src import saving
|
13 | 13 | from keras.src import testing
|
| 14 | +from keras.src.backend.common.keras_tensor import KerasTensor |
14 | 15 | from keras.src.dtype_policies import dtype_policy
|
15 | 16 | from keras.src.layers.core.input_layer import Input
|
16 | 17 | from keras.src.layers.input_spec import InputSpec
|
17 | 18 | from keras.src.models import Functional
|
18 | 19 | from keras.src.models import Model
|
19 | 20 | from keras.src.models import Sequential
|
| 21 | +from keras.src.models.model import model_from_json |
20 | 22 |
|
21 | 23 |
|
22 | 24 | class FunctionalTest(testing.TestCase):
|
@@ -273,6 +275,27 @@ def test_restored_multi_output_type(self, out_type):
|
273 | 275 | out_val = model_restored(Input(shape=(3,), batch_size=2))
|
274 | 276 | self.assertIsInstance(out_val, out_type)
|
275 | 277 |
|
| 278 | + def test_restored_nested_input(self): |
| 279 | + input_a = Input(shape=(3,), batch_size=2, name="input_a") |
| 280 | + x = layers.Dense(5)(input_a) |
| 281 | + outputs = layers.Dense(4)(x) |
| 282 | + model = Functional([[input_a]], outputs) |
| 283 | + |
| 284 | + # Serialize and deserialize the model |
| 285 | + json_config = model.to_json() |
| 286 | + restored_json_config = model_from_json(json_config).to_json() |
| 287 | + |
| 288 | + # Check that the serialized model is the same as the original |
| 289 | + self.assertEqual(json_config, restored_json_config) |
| 290 | + |
| 291 | + def test_functional_input_shape_and_type(self): |
| 292 | + input = layers.Input((1024, 4)) |
| 293 | + conv = layers.Conv1D(32, 3)(input) |
| 294 | + model = Functional(input, conv) |
| 295 | + |
| 296 | + self.assertIsInstance(model.input, KerasTensor) |
| 297 | + self.assertEqual(model.input_shape, (None, 1024, 4)) |
| 298 | + |
276 | 299 | @pytest.mark.requires_trainable_backend
|
277 | 300 | def test_layer_getters(self):
|
278 | 301 | # Test mixing ops and layers
|
|
0 commit comments