Commit 26f6499

Docs for building PointInferenceNetwork from scores
1 parent c2ca810 commit 26f6499

File tree

2 files changed: +79 -7 lines changed


bayesflow/networks/point_inference_network.py

Lines changed: 11 additions & 3 deletions

@@ -23,9 +23,7 @@ def __init__(
         subnet: str | type = "mlp",
         **kwargs,
     ):
-        super().__init__(
-            **keras_kwargs(kwargs)
-        )  # TODO: need for bf.utils.keras_kwargs in regular InferenceNetwork class? seems to be a bug
+        super().__init__(**keras_kwargs(kwargs))

         self.scores = scores

@@ -38,6 +36,16 @@ def __init__(
         self.config["scores"] = serialize(self.scores)

     def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None:
+        """Builds all network components based on shapes of conditions and targets.
+
+        For each score, corresponding estimation heads are constructed.
+        There are two steps in this:
+
+        #. Request a dictionary of names and output shapes of required heads from the score.
+        #. Then, for each required head, request the corresponding head network from the score.
+
+        Since the score is in charge of constructing heads, this allows for convenient yet flexible building.
+        """
         if conditions_shape is None:  # unconditional estimation uses a fixed input vector
             input_shape = (1, 1)
         else:
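The two-step procedure described in the new ``build`` docstring can also be sketched outside of ``PointInferenceNetwork``. The snippet below is a minimal illustration, not code from this commit; it assumes a ``MeanScore`` constructed with default arguments and an arbitrary target shape of ``(32, 4)``, and the exact head names and shapes returned depend on the score.

from bayesflow.scores import MeanScore

# Minimal sketch of the two-step head construction described in the build docstring.
# The target shape (batch_size, num_variables) = (32, 4) is an arbitrary choice for illustration.
score = MeanScore()
target_shape = (32, 4)

# Step 1: request names and output shapes of the heads this score requires.
head_shapes = score.get_head_shapes_from_target_shape(target_shape)

# Step 2: request the corresponding head network for each required head.
heads = {key: score.get_head(key, shape) for key, shape in head_shapes.items()}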

bayesflow/scores/scores.py

Lines changed: 68 additions & 4 deletions

@@ -57,7 +57,8 @@ def from_config(cls, config):

         return cls(**config)

-    def get_head_shapes_from_target_shape(self, target_shape):
+    def get_head_shapes_from_target_shape(self, target_shape: Shape) -> dict[str, Shape]:
+        """Request a dictionary of names and output shapes of required heads from the score."""
         raise NotImplementedError

     def get_subnet(self, key: str) -> keras.Layer:

@@ -79,23 +80,86 @@ def get_subnet(self, key: str) -> keras.Layer:
         else:
             return find_network(self.subnets[key], **self.subnets_kwargs.get(key, {}))

-    def get_link(self, key: str):
+    def get_link(self, key: str) -> keras.Layer:
+        """For a specified key, request a link from network output to estimation target.
+
+        Parameters
+        ----------
+        key : str
+            Name of head for which to request a link.
+
+        Returns
+        -------
+        link : keras.Layer
+            Activation function linking network output to estimation target.
+        """
         if key not in self.links.keys():
             return keras.layers.Activation("linear")
         elif isinstance(self.links[key], str):
             return keras.layers.Activation(self.links[key])
         else:
             return self.links[key]

-    def get_head(self, key: str, shape: Shape):
+    def get_head(self, key: str, shape: Shape) -> keras.Sequential:
+        """For a specified head key and shape, request the corresponding head network.
+
+        Parameters
+        ----------
+        key : str
+            Name of head for which to request a head network.
+
+        Returns
+        -------
+        head : keras.Sequential
+            Head network consisting of a learnable projection, a reshape and a link operation
+            to parameterize estimates.
+        """
         subnet = self.get_subnet(key)
         dense = keras.layers.Dense(units=math.prod(shape))
         reshape = keras.layers.Reshape(target_shape=shape)
         link = self.get_link(key)
         return keras.Sequential([subnet, dense, reshape, link])

     def score(self, estimates: dict[str, Tensor], targets: Tensor, weights: Tensor) -> Tensor:
-        """Scores a probabilistic estimate based of a distribution based on samples of that distribution."""
+        """Scores a batch of probabilistic estimates of distributions based on samples
+        of the corresponding distributions.
+
+        Parameters
+        ----------
+        estimates : dict[str, Tensor]
+            Dictionary of estimates.
+        targets : Tensor
+            Tensor of samples from the true distribution used to evaluate the estimates.
+
+        Returns
+        -------
+        numeric_score : Tensor
+            Negatively oriented score evaluating the estimates, aggregated for the whole batch.
+
+        Examples
+        --------
+        The following shows how to score estimates with a ``MeanScore``. All ``ScoringRule`` subclasses
+        follow this pattern, only differing in the structure of the estimates dictionary.
+
+        >>> import keras
+        >>> from bayesflow.scores import MeanScore
+        >>>
+        >>> # batch of samples from a normal distribution
+        >>> samples = keras.random.normal(shape=(100,))
+        >>>
+        >>> # batch of uninformed (random) estimates
+        >>> bad_estimates = {"value": keras.random.uniform((100,))}
+        >>>
+        >>> # batch of estimates that are closer to the true mean
+        >>> better_estimates = {"value": keras.random.normal(stddev=0.1, shape=(100,))}
+        >>>
+        >>> # calculate the score
+        >>> scoring_rule = MeanScore()
+        >>> scoring_rule.score(bad_estimates, samples)
+        <tf.Tensor: shape=(), dtype=float32, numpy=1.2243813276290894>
+        >>> scoring_rule.score(better_estimates, samples)
+        <tf.Tensor: shape=(), dtype=float32, numpy=1.013983130455017>
+        """
         raise NotImplementedError

     def aggregate(self, scores: Tensor, weights: Tensor = None):
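Both ``score`` and ``get_head_shapes_from_target_shape`` remain ``NotImplementedError`` stubs in the base class, so concrete scoring rules override them. The subclass below is a hypothetical sketch meant only to show which methods a subclass supplies; the class name ``MedianScore``, the head key ``"value"``, the absolute-error loss, and the assumption that ``ScoringRule`` can be subclassed with default constructor arguments are illustrative choices, not part of this commit.

import keras
from bayesflow.scores import ScoringRule

class MedianScore(ScoringRule):
    """Hypothetical scoring rule sketch: absolute error, which is minimized by the median."""

    def get_head_shapes_from_target_shape(self, target_shape):
        # A single head named "value" predicting one point estimate per target variable.
        return {"value": (target_shape[-1],)}

    def score(self, estimates, targets, weights=None):
        # Pointwise absolute error, aggregated over the batch via the base class helper.
        pointwise = keras.ops.abs(estimates["value"] - targets)
        return self.aggregate(pointwise, weights)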
