Skip to content

Commit c2ca810

Browse files
committed
Refactor to remove set_head_shapes_from_target_shape
1 parent aa20868 commit c2ca810

File tree

4 files changed

+30
-19
lines changed

4 files changed

+30
-19
lines changed

bayesflow/networks/point_inference_network.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,12 +52,12 @@ def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None:
5252
self.heads_flat = dict() # see comment regarding heads_flat below
5353

5454
for score_key, score in self.scores.items():
55-
score.set_head_shapes_from_target_shape(xz_shape)
55+
head_shapes = score.get_head_shapes_from_target_shape(xz_shape)
5656

5757
self.heads[score_key] = {}
5858

59-
for head_key in score.head_shapes.keys():
60-
head = score.get_head(head_key)
59+
for head_key, head_shape in head_shapes.items():
60+
head = score.get_head(head_key, head_shape)
6161
head.build(body_output_shape)
6262
# If head is not tracked explicitly, self.variables does not include them.
6363
# Testing with tests.utils.assert_layers_equal() would thus neglect heads

bayesflow/scores/scores.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,20 @@ def from_config(cls, config):
6060
def get_head_shapes_from_target_shape(self, target_shape):
6161
raise NotImplementedError
6262

63-
def set_head_shapes_from_target_shape(self, target_shape):
64-
self.head_shapes = self.get_head_shapes_from_target_shape(target_shape)
65-
66-
def get_subnet(self, key: str):
63+
def get_subnet(self, key: str) -> keras.Layer:
64+
"""For a specified key, request a subnet to be used for projecting the shared condition embedding
65+
before reshaping to the heads output shape.
66+
67+
Parameters
68+
----------
69+
key : str
70+
Name of head for which to request a link.
71+
72+
Returns
73+
-------
74+
link : keras.Layer
75+
Subnet projecting the shared condition embedding.
76+
"""
6777
if key not in self.subnets.keys():
6878
return keras.layers.Identity()
6979
else:
@@ -77,15 +87,15 @@ def get_link(self, key: str):
7787
else:
7888
return self.links[key]
7989

80-
def get_head(self, key: str):
90+
def get_head(self, key: str, shape: Shape):
8191
subnet = self.get_subnet(key)
82-
head_shape = self.head_shapes[key]
83-
dense = keras.layers.Dense(units=math.prod(head_shape))
84-
reshape = keras.layers.Reshape(target_shape=head_shape)
92+
dense = keras.layers.Dense(units=math.prod(shape))
93+
reshape = keras.layers.Reshape(target_shape=shape)
8594
link = self.get_link(key)
8695
return keras.Sequential([subnet, dense, reshape, link])
8796

8897
def score(self, estimates: dict[str, Tensor], targets: Tensor, weights: Tensor) -> Tensor:
98+
"""Scores a probabilistic estimate based of a distribution based on samples of that distribution."""
8999
raise NotImplementedError
90100

91101
def aggregate(self, scores: Tensor, weights: Tensor = None):
@@ -114,7 +124,7 @@ def __init__(
114124
"k": k,
115125
}
116126

117-
def get_head_shapes_from_target_shape(self, target_shape):
127+
def get_head_shapes_from_target_shape(self, target_shape: Shape):
118128
# keras.saving.load_model sometimes passes target_shape as a list.
119129
# This is why I force a conversion to tuple here.
120130
target_shape = tuple(target_shape)
@@ -180,7 +190,7 @@ def get_config(self):
180190
base_config = super().get_config()
181191
return base_config | self.config
182192

183-
def get_head_shapes_from_target_shape(self, target_shape):
193+
def get_head_shapes_from_target_shape(self, target_shape: Shape):
184194
# keras.saving.load_model sometimes passes target_shape as a list.
185195
# This is why I force a conversion to tuple here.
186196
target_shape = tuple(target_shape)
@@ -240,7 +250,7 @@ def get_config(self):
240250
base_config = super().get_config()
241251
return base_config | self.config
242252

243-
def get_head_shapes_from_target_shape(self, target_shape) -> dict[str, Shape]:
253+
def get_head_shapes_from_target_shape(self, target_shape: Shape) -> dict[str, Shape]:
244254
self.D = target_shape[-1]
245255
return dict(
246256
mean=(self.D,),

tests/test_networks/test_point_inference_network/test_point_inference_network.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,10 @@ def test_output_structure(point_inference_network, random_samples, random_condit
1212

1313
assert isinstance(output, dict)
1414
for score_key, score in point_inference_network.scores.items():
15-
assert isinstance(score.head_shapes, dict)
15+
head_shapes = score.get_head_shapes_from_target_shape(random_samples.shape)
16+
assert isinstance(head_shapes, dict)
1617

17-
for head_key, head_shape in score.head_shapes.items():
18+
for head_key, head_shape in head_shapes.items():
1819
head_output = output[score_key][head_key]
1920
assert keras.ops.is_tensor(head_output)
2021
assert head_output.shape[1:] == head_shape

tests/test_scores/test_scores.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,12 @@ def test_score_output(scoring_rule, random_conditions):
1515
if random_conditions is None:
1616
random_conditions = keras.ops.convert_to_tensor([[1.0]])
1717

18-
scoring_rule.set_head_shapes_from_target_shape(random_conditions.shape)
18+
# Using random random_conditions also as targets for the purpose of this test.
19+
head_shapes = scoring_rule.get_head_shapes_from_target_shape(random_conditions.shape)
1920
print(scoring_rule.get_config())
2021
estimates = {
2122
k: scoring_rule.get_link(k)(keras.random.normal((random_conditions.shape[0],) + head_shape))
22-
for k, head_shape in scoring_rule.head_shapes.items()
23+
for k, head_shape in head_shapes.items()
2324
}
2425
score = scoring_rule.score(estimates, random_conditions)
2526

@@ -30,7 +31,6 @@ def test_mean_score_optimality(mean_score, random_conditions):
3031
if random_conditions is None:
3132
random_conditions = keras.ops.convert_to_tensor([[1.0]])
3233

33-
mean_score.set_head_shapes_from_target_shape(random_conditions.shape)
3434
key = "value"
3535
suboptimal_estimates = {key: keras.random.uniform(random_conditions.shape)}
3636
optimal_estimates = {key: random_conditions}

0 commit comments

Comments
 (0)