Skip to content

Commit 7dae3e9

Browse files
committed
Fix issue with deserialization of nested and shared Functional models.
1 parent 032cdff commit 7dae3e9

File tree

3 files changed

+53
-4
lines changed

3 files changed

+53
-4
lines changed

keras/models/functional.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -480,7 +480,7 @@ def process_node(layer, node_data):
480480
layer(*args, **kwargs)
481481

482482
def process_layer(layer_data):
483-
"""Deserializes a layer, then call it on appropriate inputs.
483+
"""Deserializes a layer and index its inbound nodes.
484484
485485
Args:
486486
layer_data: layer config dict.
@@ -553,6 +553,9 @@ def process_layer(layer_data):
553553
def get_tensor(layer_name, node_index, tensor_index):
554554
assert layer_name in created_layers
555555
layer = created_layers[layer_name]
556+
if isinstance(layer, Functional):
557+
# Functional models start out with a built-in node.
558+
node_index -= 1
556559
layer_output_tensors = layer._inbound_nodes[node_index].output_tensors
557560
return layer_output_tensors[tensor_index]
558561

@@ -613,8 +616,9 @@ def serialize_keras_tensor(x):
613616
if isinstance(x, backend.KerasTensor):
614617
operation, node_index, tensor_index = x._keras_history
615618
irrelevant_node_count = 0
616-
for node in operation._inbound_nodes[:node_index]:
617-
if node not in own_nodes:
619+
for i, node in enumerate(operation._inbound_nodes[:node_index]):
620+
node_key = make_node_key(operation, i)
621+
if node_key not in own_nodes:
618622
irrelevant_node_count += 1
619623
x._keras_history = KerasHistory(
620624
operation, node_index - irrelevant_node_count, tensor_index

keras/ops/node.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def __init__(
8282
self.is_input = not self.arguments.keras_tensors
8383

8484
def __repr__(self):
85-
return f"<Node operation={self.operation}, id={id(self)}>"
85+
return f"<Node operation={self.operation.name}, id={id(self)}>"
8686

8787
@property
8888
def input_tensors(self):

keras/saving/saving_lib_test.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -861,3 +861,48 @@ def test_legacy_h5_format(self):
861861
new_model = keras.saving.load_model(temp_filepath)
862862
out = new_model(x)
863863
self.assertAllClose(ref_out, out, atol=1e-6)
864+
865+
def test_nested_functional_model_saving(self):
866+
def func(in_size=4, out_size=2, name=None):
867+
inputs = keras.layers.Input(shape=(in_size,))
868+
outputs = keras.layers.Dense(out_size)((inputs))
869+
return keras.Model(inputs, outputs=outputs, name=name)
870+
871+
input_a, input_b = keras.Input((4,)), keras.Input((4,))
872+
out_a = func(out_size=2, name="func_a")(input_a)
873+
out_b = func(out_size=3, name="func_b")(input_b)
874+
model = keras.Model([input_a, input_b], outputs=[out_a, out_b])
875+
876+
temp_filepath = os.path.join(self.get_temp_dir(), "nested_func.keras")
877+
model.save(temp_filepath)
878+
new_model = keras.saving.load_model(temp_filepath)
879+
x = [np.random.random((2, 4))], np.random.random((2, 4))
880+
ref_out = model(x)
881+
out = new_model(x)
882+
self.assertAllClose(ref_out[0], out[0])
883+
self.assertAllClose(ref_out[1], out[1])
884+
885+
def test_nested_shared_functional_model_saving(self):
886+
def func(in_size=4, out_size=2, name=None):
887+
inputs = keras.layers.Input(shape=(in_size,))
888+
outputs = keras.layers.Dense(out_size)((inputs))
889+
return keras.Model(inputs, outputs=outputs, name=name)
890+
891+
inputs = [keras.Input((4,)), keras.Input((4,))]
892+
func_shared = func(out_size=4, name="func_shared")
893+
shared_a = func_shared(inputs[0])
894+
shared_b = func_shared(inputs[1])
895+
out_a = keras.layers.Dense(2)(shared_a)
896+
out_b = keras.layers.Dense(2)(shared_b)
897+
model = keras.Model(inputs, outputs=[out_a, out_b])
898+
899+
temp_filepath = os.path.join(
900+
self.get_temp_dir(), "nested_shared_func.keras"
901+
)
902+
model.save(temp_filepath)
903+
new_model = keras.saving.load_model(temp_filepath)
904+
x = [np.random.random((2, 4))], np.random.random((2, 4))
905+
ref_out = model(x)
906+
out = new_model(x)
907+
self.assertAllClose(ref_out[0], out[0])
908+
self.assertAllClose(ref_out[1], out[1])

0 commit comments

Comments
 (0)