Skip to content

Commit 0cd3110

Browse files
committed
automatic head building for multiple scoring rules
1 parent 60050ec commit 0cd3110

File tree

6 files changed

+185
-9
lines changed

6 files changed

+185
-9
lines changed

bayesflow/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
)
1313

1414
from .workflows import BasicWorkflow
15-
from .approximators import ContinuousApproximator, ContinuousPointApproximator
15+
from .approximators import ContinuousApproximator, PointApproximator
1616
from .adapters import Adapter
1717
from .datasets import OfflineDataset, OnlineDataset, DiskDataset
1818
from .simulators import make_simulator
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
from .approximator import Approximator
22
from .continuous_approximator import ContinuousApproximator
3-
from .continuous_point_approximator import ContinuousPointApproximator
3+
from .point_approximator import PointApproximator
44
from .model_comparison_approximator import ModelComparisonApproximator

bayesflow/approximators/continuous_point_approximator.py renamed to bayesflow/approximators/point_approximator.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717

1818
@serializable(package="bayesflow.approximators")
19-
class ContinuousPointApproximator(Approximator):
19+
class PointApproximator(Approximator):
2020
"""
2121
Defines a workflow for performing fast posterior or likelihood inference.
2222
The distribution is approximated by a point with a feed-forward network and an optional summary network.
@@ -142,7 +142,12 @@ def estimate(
142142
conditions = keras.tree.map_structure(keras.ops.convert_to_tensor, conditions)
143143
conditions = {"inference_variables": self._estimate(**conditions)}
144144
conditions = keras.tree.map_structure(keras.ops.convert_to_numpy, conditions)
145-
conditions = self.adapter(conditions, inverse=True, strict=False, **kwargs)
145+
conditions["inference_variables"] = {
146+
key: self.adapter(
147+
dict(inference_variables=conditions["inference_variables"][key]), inverse=True, strict=False, **kwargs
148+
)
149+
for key in conditions["inference_variables"].keys()
150+
}
146151

147152
if split:
148153
conditions = split_arrays(conditions, axis=-1)
Lines changed: 98 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,89 @@
11
import keras
22

3+
from math import prod
4+
5+
from collections.abc import Callable
6+
7+
from bayesflow.utils import keras_kwargs, find_network
38
from bayesflow.types import Shape, Tensor
9+
from bayesflow.scoring_rules import ScoringRule
10+
11+
# TODO:
12+
# * [ ] weight initialization
13+
# * [ ] serializable ?
14+
# * [ ] testing
15+
# * [ ] docstrings
416

517

618
class PointInferenceNetwork(keras.Layer):
7-
def __init__(self, **kwargs):
8-
super().__init__(**kwargs)
19+
def __init__(
    self,
    scoring_rules: dict[str, ScoringRule],
    body_subnet: str | type = "mlp",  # naming: shared_subnet / body / subnet ?
    heads_subnet: dict[str, str | keras.Layer] | None = None,  # TODO: `type` instead of `keras.Layer` ?
    activations: dict[str, keras.layers.Activation | Callable | str] | None = None,
    **kwargs,
):
    """
    Set up a shared body network and one head per scoring rule.

    Parameters
    ----------
    scoring_rules : dict[str, ScoringRule]
        Scoring rules keyed by name. The same keys identify the heads and activations.
    body_subnet : str or type, optional
        Specification of the shared body network, resolved with `find_network`.
        Extra arguments are read from `kwargs["body_subnet_kwargs"]`.
    heads_subnet : dict, optional
        Per-key specification of an additional subnet placed at the start of each head,
        resolved with `find_network`. Extra arguments are read per key from
        `kwargs["heads_subnet_kwargs"]`. If empty or None, heads have no extra subnet.
    activations : dict, optional
        Per-key output activation. Values that are not `keras.layers.Activation`
        instances are wrapped in one. If empty or None, a linear activation is used.
    **kwargs
        Passed to the parent constructor after filtering with `keras_kwargs`.

    Raises
    ------
    ValueError
        If the keys of `scoring_rules`, the heads, and the activations do not coincide.
    """
    super().__init__(
        **keras_kwargs(kwargs)
    )  # TODO: need for bf.utils.keras_kwargs in regular InferenceNetwork class? seems to be a bug

    self.scoring_rules = scoring_rules
    # For now PointInferenceNetwork uses the same scoring rules for all parameters
    # To support using different sets of scoring rules for different parameter (blocks),
    # we can look into renaming this class to sth like `HeadCollection` and
    # handle the split in a higher-level object. (PointApproximator?)

    self.body_subnet = find_network(body_subnet, **kwargs.get("body_subnet_kwargs", {}))

    if heads_subnet:
        heads_subnet_kwargs = kwargs.get("heads_subnet_kwargs", {})
        heads = {
            key: [find_network(value, **heads_subnet_kwargs.get(key, {}))] for key, value in heads_subnet.items()
        }
    else:
        heads = {key: [] for key in scoring_rules}

    if activations:
        # make sure that each value is an Activation object
        self.activations = {
            key: (value if isinstance(value, keras.layers.Activation) else keras.layers.Activation(value))
            for key, value in activations.items()
        }
    else:
        self.activations = {key: keras.layers.Activation("linear") for key in scoring_rules}
        # TODO: Stefan suggested to call these link functions, decide on this

    # TODO: allow key-wise overriding of the default, instead of just complete default or totally custom choices

    # Validate before assembling the heads: a missing activation key would otherwise surface
    # as a bare KeyError below, and an `assert` would be skipped when Python runs with -O.
    if not (set(scoring_rules) == set(heads) == set(self.activations)):
        raise ValueError(
            "Keys of `scoring_rules`, `heads_subnet` and `activations` must coincide, but got "
            f"{sorted(scoring_rules)}, {sorted(heads)} and {sorted(self.activations)}."
        )

    for key, head in heads.items():
        # Units and target shape are placeholders; `build` fills them in once shapes are known.
        head.extend(
            [
                keras.layers.Dense(units=None),
                keras.layers.Reshape(target_shape=(None,)),
                self.activations[key],
            ]
        )

    # NOTE(review): the head lists are completed before the single assignment to `self.heads`;
    # whether Keras attribute tracking would also register later in-place `+=` extensions
    # was not verified.
    self.heads = heads
966

1067
def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None:
    """
    Build the shared body network and every head, fixing the head output sizes.

    Parameters
    ----------
    xz_shape : Shape
        Shape of the inference variables. The first entry is treated as the batch
        dimension and dropped when computing head output shapes.
    conditions_shape : Shape, optional
        Shape of the conditions, used as the input shape of the body network.

    Notes
    -----
    Each head is expected to end with the three blocks appended in `__init__`:
    a dense layer, a reshape layer and an activation, addressed here as `[-3]` and `[-2]`.
    """
    # build the shared body network
    self.body_subnet.build(conditions_shape)
    body_output_shape = self.body_subnet.compute_output_shape(conditions_shape)

    # Normalize to a tuple so concatenation below also works when the shape is a list
    # or another sequence type.
    parameter_block_shape = tuple(xz_shape[1:])

    for key, head in self.heads.items():
        # head_output_shape (excluding batch_size) convention is (*prediction_shape, *parameter_block_shape)
        prediction_shape = tuple(self.scoring_rules[key].prediction_shape)
        head_output_shape = prediction_shape + parameter_block_shape

        # set correct head shape
        head[-3].units = prod(head_output_shape)
        head[-2].target_shape = head_output_shape

        # build head block by block, feeding each block the output shape of the previous one
        block_input_shape = body_output_shape
        for head_block in head:
            head_block.build(block_input_shape)
            block_input_shape = head_block.compute_output_shape(block_input_shape)
1287

1388
def call(
1489
self,
@@ -17,19 +92,37 @@ def call(
1792
training: bool = False,
1893
**kwargs,
1994
) -> Tensor | tuple[Tensor, Tensor]:
95+
# TODO: remove unnecessary similarity with InferenceNetwork
2096
return self._forward(xz, conditions=conditions, training=training, **kwargs)
2197

2298
def _forward(
2399
self, x: Tensor, conditions: Tensor = None, training: bool = False, **kwargs
24100
) -> Tensor | tuple[Tensor, Tensor]:
25-
raise NotImplementedError
101+
body_output = self.body_subnet(conditions)
102+
103+
output = dict()
104+
for key, head in self.heads.items():
105+
y = body_output
106+
for head_block in head:
107+
y = head_block(y)
108+
109+
output |= {key: y}
110+
return output
26111

27112
def compute_metrics(self, x: Tensor, conditions: Tensor = None, stage: str = "training") -> dict[str, Tensor]:
28113
if not self.built:
29114
xz_shape = keras.ops.shape(x)
30115
conditions_shape = None if conditions is None else keras.ops.shape(conditions)
31116
self.build(xz_shape, conditions_shape=conditions_shape)
32117

118+
output = self(x, conditions)
119+
120+
# calculate negative score as mean over all heads
121+
neg_score = 0
122+
for key, rule in self.scoring_rules.items():
123+
neg_score += rule.score(output[key], x)
124+
neg_score /= len(self.scoring_rules)
125+
33126
metrics = {}
34127

35128
if stage != "training" and any(self.metrics):
@@ -41,7 +134,7 @@ def compute_metrics(self, x: Tensor, conditions: Tensor = None, stage: str = "tr
41134
pass
42135
# TODO: instead compute estimate based metrics
43136

44-
return metrics
137+
return metrics | {"loss": neg_score}
45138

46139
def estimate(self, conditions: Tensor = None) -> Tensor:
    """Return the output of `_forward` for `conditions`, passing None in place of inference variables."""
    predictions = self._forward(None, conditions)
    return predictions
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .scoring_rules import ScoringRule, NormedDifferenceLoss, QuantileLoss
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
from collections.abc import Callable, Sequence
2+
3+
from bayesflow.types import Tensor
4+
5+
import keras
6+
7+
8+
class ScoringRule:
    """Base class for scoring rules; subclasses are expected to override `score`."""

    def __init__(self, name: str = None):
        """Store the optional name of the scoring rule."""
        # TODO: names for scoring rules may be unnecessary ?
        self.name = name

    def score(self, target, reference):
        """Score `target` against `reference`; not implemented in the base class."""
        raise NotImplementedError()
17+
18+
19+
class NormedDifferenceLoss(ScoringRule):
    """Scoring rule that averages the absolute difference between target and reference raised to the power `k`."""

    def __init__(
        self,
        k: int = 2,  # results in an estimator for the mean
        name: str = "normed_difference",
    ):
        """
        Parameters
        ----------
        k : int, optional
            Exponent applied to the absolute pointwise difference.
        name : str, optional
            Name of the scoring rule.
        """
        super().__init__(name)

        self.k = k
        # `PointInferenceNetwork.build` reads `prediction_shape`; `target_shape` is kept
        # as an alias so existing code that reads it keeps working.
        self.prediction_shape = (1,)
        self.target_shape = self.prediction_shape

    def score(self, target: Tensor, reference: Tensor) -> Tensor:
        """
        Return the mean of `|target - reference| ** k` over all entries.

        `reference` is indexed as `[:, None, :]`, i.e. it is treated as rank 2 and gains
        a new second axis before the subtraction.
        """
        pointwise_difference = target - reference[:, None, :]
        score = keras.ops.absolute(pointwise_difference) ** self.k
        return keras.ops.mean(score)
35+
36+
37+
class WeightedNormedDifferenceLoss(ScoringRule):
    """Like a normed difference loss, with each term multiplied by a weight computed from the reference."""

    def __init__(
        self,
        weighting_function: Callable | None = None,
        k: int = 2,
        name: str = "weighted_normed_difference",
    ):
        """
        Parameters
        ----------
        weighting_function : Callable, optional
            Called with the reference to obtain the weights. If omitted (or falsy),
            a constant weight of 1 is used.
        k : int, optional
            Exponent applied to the absolute pointwise difference.
        name : str, optional
            Name of the scoring rule.
        """
        super().__init__(name)

        # Fall back to a constant unit weight when no weighting function is supplied.
        self.weighting_function = weighting_function or (lambda reference: 1)
        self.k = k
        # `PointInferenceNetwork.build` reads `prediction_shape`; `target_shape` is kept
        # as an alias so existing code that reads it keeps working.
        self.prediction_shape = (1,)
        self.target_shape = self.prediction_shape

    def score(self, target: Tensor, reference: Tensor) -> Tensor:
        """
        Return the mean of `weight * |target - reference| ** k` over all entries.

        `reference` is indexed as `[:, None, :]`, i.e. it is treated as rank 2 and gains
        a new second axis before the subtraction.
        """
        pointwise_difference = target - reference[:, None, :]

        weights = self.weighting_function(reference)
        if getattr(weights, "ndim", 0) == 2:
            # Give rank-2 weights the same extra axis as the reference; otherwise they would
            # broadcast against the rank-3 difference along the wrong axes.
            weights = weights[:, None, :]

        score = weights * keras.ops.absolute(pointwise_difference) ** self.k
        return keras.ops.mean(score)
58+
59+
60+
class QuantileLoss(ScoringRule):
    """Scoring rule built from the pointwise difference weighted by quantile-level dependent factors."""

    def __init__(
        self,
        quantile_levels: Sequence[float] | None = None,
        name: str = "quantile",
    ):
        """
        Parameters
        ----------
        quantile_levels : Sequence[float], optional
            Quantile levels to score. Defaults to `[0.1, 0.5, 0.9]`.
        name : str, optional
            Name of the scoring rule.
        """
        super().__init__(name)

        # A None sentinel avoids sharing one mutable default list between instances.
        if quantile_levels is None:
            quantile_levels = [0.1, 0.5, 0.9]

        # Take the length from the input sequence before it is converted to a tensor.
        num_levels = len(quantile_levels)
        self.quantile_levels = keras.ops.convert_to_tensor(quantile_levels)

        # `PointInferenceNetwork.build` reads `prediction_shape`; `target_shape` is kept
        # as an alias so existing code that reads it keeps working.
        self.prediction_shape = (num_levels,)
        self.target_shape = self.prediction_shape

    def score(self, target: Tensor, reference: Tensor) -> Tensor:
        """
        Return the mean over all entries of `d * (1[d > 0] - level)` with `d = target - reference`.

        `reference` is indexed as `[:, None, :]` and the quantile levels as `[None, :, None]`,
        so the levels are aligned with the second axis of the difference.
        """
        pointwise_difference = target - reference[:, None, :]

        is_positive = keras.ops.cast(pointwise_difference > 0, float)
        score = pointwise_difference * (is_positive - self.quantile_levels[None, :, None])
        return keras.ops.mean(score)

0 commit comments

Comments
 (0)