Skip to content

Commit 5a2ea53

Browse files
committed
Two proposals to include the names of head layers in serialization
1 parent 3d56323 commit 5a2ea53

File tree

2 files changed

+62
-2
lines changed

2 files changed

+62
-2
lines changed

bayesflow/networks/point_inference_network.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,10 @@ def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None:
5151
else:
5252
input_shape = conditions_shape
5353

54+
# Save input_shape and xz_shape for usage in get_build_config
55+
self._input_shape = input_shape
56+
self._xz_shape = xz_shape
57+
5458
# build the shared body network
5559
self.subnet.build(input_shape)
5660
body_output_shape = self.subnet.compute_output_shape(input_shape)
@@ -82,6 +86,62 @@ def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None:
8286
flat_key = f"{score_key}___{head_key}"
8387
self.heads_flat[flat_key] = head
8488

89+
def get_build_config(self):
90+
build_config = {
91+
"conditions_shape": self._input_shape,
92+
"xz_shape": self._xz_shape,
93+
}
94+
95+
# Save names of head networks
96+
heads = {}
97+
for score_key in self.heads.keys():
98+
heads[score_key] = {}
99+
for head_key, head in self.heads[score_key].items():
100+
heads[score_key][head_key] = head.name
101+
# Alternatively, save full build config of head
102+
# heads[score_key][head_key] = head.get_build_config()
103+
# TODO: decide
104+
105+
build_config["heads"] = heads
106+
107+
return build_config
108+
109+
def build_from_config(self, config):
110+
self.build(xz_shape=config["xz_shape"], conditions_shape=config["conditions_shape"])
111+
112+
for score_key in self.scores.keys():
113+
for head_key, head in self.heads[score_key].items():
114+
head.name = config["heads"][score_key][head_key]
115+
116+
# Alternatively, do NOT call self.build, but rather imitate it using the build config of each head
117+
# This results in some code duplication with self.build and requires heads to be of a custom type.
118+
# TODO: decide
119+
120+
# input_shape = config["conditions_shape"]
121+
#
122+
# # Save input_shape for usage in get_build_config
123+
# self._input_shape = input_shape
124+
#
125+
# # build the shared body network
126+
# self.subnet.build(input_shape)
127+
#
128+
# # build head(s) for every scoring rule
129+
# self.heads = dict()
130+
# self.heads_flat = dict()
131+
#
132+
# for score_key in self.scores.keys():
133+
#
134+
# self.heads[score_key] = {}
135+
#
136+
# for head_key, head_config in config["heads"][score_key].items():
137+
# head = keras.Sequential()
138+
# head.build_from_config(head_config) # TODO: this method is not implemented yet
139+
# it would require the head to be a
140+
# custom object rather than a Sequential
141+
# self.heads[score_key][head_key] = head
142+
# flat_key = f"{score_key}___{head_key}"
143+
# self.heads_flat[flat_key] = head
144+
85145
def get_config(self):
86146
base_config = super().get_config()
87147

tests/utils/assertions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,5 +32,5 @@ def assert_layers_equal(layer1: keras.Layer, layer2: keras.Layer):
3232
msg = f"Variable '{v1.name}' for Layer '{layer1.name}' is not equal: {x1} != {x2}"
3333
assert keras.ops.all(keras.ops.isclose(x1, x2)), msg
3434

35-
# The names of layers need not stay the same
36-
# assert layer1.name == layer2.name
35+
msg = f"Layers {layer1.name} and {layer2.name} have a different name."
36+
assert layer1.name == layer2.name, msg

0 commit comments

Comments
 (0)