Skip to content

Commit 4d33a9a

Browse files
Ensures consistent serialization for Functional model (#21341)
* Fixes inconsitent serialization for nested inputs * Adds test case to validate input shape * renames methods
1 parent 3d2db56 commit 4d33a9a

File tree

2 files changed

+23
-6
lines changed

2 files changed

+23
-6
lines changed

keras/src/models/functional.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -453,8 +453,6 @@ def get_tensor_config(tensor):
453453
return [operation.name, new_node_index, tensor_index]
454454

455455
def map_tensors(tensors):
456-
if isinstance(tensors, backend.KerasTensor):
457-
return [get_tensor_config(tensors)]
458456
return tree.map_structure(get_tensor_config, tensors)
459457

460458
config["input_layers"] = map_tensors(self._inputs_struct)
@@ -621,10 +619,6 @@ def map_tensors(tensors):
621619

622620
input_tensors = map_tensors(functional_config["input_layers"])
623621
output_tensors = map_tensors(functional_config["output_layers"])
624-
if isinstance(input_tensors, list) and len(input_tensors) == 1:
625-
input_tensors = input_tensors[0]
626-
if isinstance(output_tensors, list) and len(output_tensors) == 1:
627-
output_tensors = output_tensors[0]
628622

629623
return cls(
630624
inputs=input_tensors,

keras/src/models/functional_test.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,14 @@
1111
from keras.src import ops
1212
from keras.src import saving
1313
from keras.src import testing
14+
from keras.src.backend.common.keras_tensor import KerasTensor
1415
from keras.src.dtype_policies import dtype_policy
1516
from keras.src.layers.core.input_layer import Input
1617
from keras.src.layers.input_spec import InputSpec
1718
from keras.src.models import Functional
1819
from keras.src.models import Model
1920
from keras.src.models import Sequential
21+
from keras.src.models.model import model_from_json
2022

2123

2224
class FunctionalTest(testing.TestCase):
@@ -273,6 +275,27 @@ def test_restored_multi_output_type(self, out_type):
273275
out_val = model_restored(Input(shape=(3,), batch_size=2))
274276
self.assertIsInstance(out_val, out_type)
275277

278+
def test_restored_nested_input(self):
279+
input_a = Input(shape=(3,), batch_size=2, name="input_a")
280+
x = layers.Dense(5)(input_a)
281+
outputs = layers.Dense(4)(x)
282+
model = Functional([[input_a]], outputs)
283+
284+
# Serialize and deserialize the model
285+
json_config = model.to_json()
286+
restored_json_config = model_from_json(json_config).to_json()
287+
288+
# Check that the serialized model is the same as the original
289+
self.assertEqual(json_config, restored_json_config)
290+
291+
def test_functional_input_shape_and_type(self):
292+
input = layers.Input((1024, 4))
293+
conv = layers.Conv1D(32, 3)(input)
294+
model = Functional(input, conv)
295+
296+
self.assertIsInstance(model.input, KerasTensor)
297+
self.assertEqual(model.input_shape, (None, 1024, 4))
298+
276299
@pytest.mark.requires_trainable_backend
277300
def test_layer_getters(self):
278301
# Test mixing ops and layers

0 commit comments

Comments
 (0)