Skip to content

Commit 23790c6

Browse files
committed
More detailed test for point inference network serialization
1 parent 50167b3 commit 23790c6

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

tests/test_networks/test_point_inference_network/test_point_inference_network.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,9 @@ def test_save_and_load_quantile(tmp_path, quantile_point_inference_network, rand
7171
keras.saving.save_model(net, tmp_path / "model.keras")
7272
loaded = keras.saving.load_model(tmp_path / "model.keras")
7373

74+
print(net.get_config())
75+
assert net.get_config() == loaded.get_config()
76+
7477
assert_layers_equal(net, loaded)
7578

7679
for score_key, score in net.scores.items():
@@ -85,5 +88,8 @@ def test_save_and_load_quantile(tmp_path, quantile_point_inference_network, rand
8588
assert keras.ops.all(keras.ops.isclose(net_head.layers[-1].q, loaded_head.layers[-1].q))
8689
assert keras.ops.all(net_head.layers[-1].anchor_index == loaded_head.layers[-1].anchor_index)
8790

91+
print(net_head.get_config())
92+
assert net_head.get_config() == loaded_head.get_config()
93+
8894
print(f"Asserting original and serialized and deserialized at heads[{score_key}][{head_key}] to be equal")
8995
assert_layers_equal(net_head, loaded_head)

0 commit comments

Comments
 (0)