@@ -861,3 +861,48 @@ def test_legacy_h5_format(self):
861
861
new_model = keras .saving .load_model (temp_filepath )
862
862
out = new_model (x )
863
863
self .assertAllClose (ref_out , out , atol = 1e-6 )
864
+
865
+ def test_nested_functional_model_saving (self ):
866
+ def func (in_size = 4 , out_size = 2 , name = None ):
867
+ inputs = keras .layers .Input (shape = (in_size ,))
868
+ outputs = keras .layers .Dense (out_size )((inputs ))
869
+ return keras .Model (inputs , outputs = outputs , name = name )
870
+
871
+ input_a , input_b = keras .Input ((4 ,)), keras .Input ((4 ,))
872
+ out_a = func (out_size = 2 , name = "func_a" )(input_a )
873
+ out_b = func (out_size = 3 , name = "func_b" )(input_b )
874
+ model = keras .Model ([input_a , input_b ], outputs = [out_a , out_b ])
875
+
876
+ temp_filepath = os .path .join (self .get_temp_dir (), "nested_func.keras" )
877
+ model .save (temp_filepath )
878
+ new_model = keras .saving .load_model (temp_filepath )
879
+ x = [np .random .random ((2 , 4 ))], np .random .random ((2 , 4 ))
880
+ ref_out = model (x )
881
+ out = new_model (x )
882
+ self .assertAllClose (ref_out [0 ], out [0 ])
883
+ self .assertAllClose (ref_out [1 ], out [1 ])
884
+
885
+ def test_nested_shared_functional_model_saving (self ):
886
+ def func (in_size = 4 , out_size = 2 , name = None ):
887
+ inputs = keras .layers .Input (shape = (in_size ,))
888
+ outputs = keras .layers .Dense (out_size )((inputs ))
889
+ return keras .Model (inputs , outputs = outputs , name = name )
890
+
891
+ inputs = [keras .Input ((4 ,)), keras .Input ((4 ,))]
892
+ func_shared = func (out_size = 4 , name = "func_shared" )
893
+ shared_a = func_shared (inputs [0 ])
894
+ shared_b = func_shared (inputs [1 ])
895
+ out_a = keras .layers .Dense (2 )(shared_a )
896
+ out_b = keras .layers .Dense (2 )(shared_b )
897
+ model = keras .Model (inputs , outputs = [out_a , out_b ])
898
+
899
+ temp_filepath = os .path .join (
900
+ self .get_temp_dir (), "nested_shared_func.keras"
901
+ )
902
+ model .save (temp_filepath )
903
+ new_model = keras .saving .load_model (temp_filepath )
904
+ x = [np .random .random ((2 , 4 ))], np .random .random ((2 , 4 ))
905
+ ref_out = model (x )
906
+ out = new_model (x )
907
+ self .assertAllClose (ref_out [0 ], out [0 ])
908
+ self .assertAllClose (ref_out [1 ], out [1 ])
0 commit comments