Skip to content

Commit d76ca9f

Browse files
committed
Docs for scores module [no-ci]
1 parent 95cd950 commit d76ca9f

File tree

3 files changed

+45
-2
lines changed

3 files changed

+45
-2
lines changed

bayesflow/scores/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
"""Scoring rules for point estimation."""
2+
13
from .scores import (
24
ScoringRule,
35
ParametricDistributionRule,

bayesflow/scores/scores.py

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,16 @@
1212

1313

1414
class ScoringRule:
15+
"""Base class for scoring rules.
16+
17+
Scoring rules evaluate the quality of statistical predictions based on the values that materialize
18+
when sampling from the true distribution. By minimizing an expected score, estimates with
19+
different properties can be obtained.
20+
21+
To define a custom ``ScoringRule``, inherit from this class and overwrite the score method.
22+
For proper serialization, any new constructor arguments must be taken care of in a `get_config` method.
23+
"""
24+
1525
def __init__(
1626
self,
1727
subnets: dict[str, str | type] = None,
@@ -87,6 +97,11 @@ def aggregate(self, scores: Tensor, weights: Tensor = None):
8797

8898

8999
class NormedDifferenceScore(ScoringRule):
100+
r""":math:`S(\hat \theta, \theta; k) = | \hat \theta - \theta |^k`
101+
102+
Scores a point estimate with the k-norm of the error.
103+
"""
104+
90105
def __init__(
91106
self,
92107
k: int,
@@ -118,18 +133,35 @@ def get_config(self):
118133

119134

120135
class MedianScore(NormedDifferenceScore):
136+
r""":math:`S(\hat \theta, \theta) = | \hat \theta - \theta |`
137+
138+
Scores a predicted median with the absolute error score.
139+
"""
140+
121141
def __init__(self, **kwargs):
122142
super().__init__(k=1, **kwargs)
123143
self.config = {}
124144

125145

126146
class MeanScore(NormedDifferenceScore):
147+
r""":math:`S(\hat \theta, \theta) = | \hat \theta - \theta |^2`
148+
149+
Scores a predicted mean with the squared error score.
150+
"""
151+
127152
def __init__(self, **kwargs):
128153
super().__init__(k=2, **kwargs)
129154
self.config = {}
130155

131156

132157
class QuantileScore(ScoringRule):
158+
r""":math:`S(\hat \theta_i, \theta; \tau_i)
159+
= (\hat \theta_i - \theta)(\mathbf{1}_{\hat \theta - \theta > 0} - \tau_i)`
160+
161+
Scores predicted quantiles :math:`\hat \theta_i` with the quantile score
162+
to match the quantile levels :math:`\hat \tau_i`.
163+
"""
164+
133165
def __init__(self, q: Sequence[float] = None, links=None, **kwargs):
134166
super().__init__(links=links, **kwargs)
135167
if q is None:
@@ -165,8 +197,10 @@ def score(self, estimates: dict[str, Tensor], targets: Tensor, weights: Tensor =
165197

166198

167199
class ParametricDistributionRule(ScoringRule):
168-
"""
169-
TODO
200+
r""":math:`S(\hat p_\phi, \theta; k) = \log(\hat p_\phi(\theta))`
201+
202+
Scores a predicted parametric probability distribution with the log-score
203+
of the probability of the materialized value.
170204
"""
171205

172206
def __init__(self, **kwargs):
@@ -186,6 +220,11 @@ def score(self, estimates: dict[str, Tensor], targets: Tensor, weights: Tensor =
186220

187221

188222
class MultivariateNormalScore(ParametricDistributionRule):
223+
r""":math:`S(\hat p_{\mu, \Sigma}, \theta; k) = \log( \mathcal N (\theta; \mu, \Sigma))`
224+
225+
Scores a predicted mean and covariance matrix with the log-score of the probability of the materialized value.
226+
"""
227+
189228
def __init__(self, D: int = None, links=None, **kwargs):
190229
super().__init__(links=links, **kwargs)
191230
self.D = D

docsrc/source/api/bayesflow.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,11 @@
1313
bayesflow.diagnostics
1414
bayesflow.distributions
1515
bayesflow.experimental
16+
bayesflow.links
1617
bayesflow.metrics
1718
bayesflow.networks
1819
bayesflow.simulators
20+
bayesflow.scores
1921
bayesflow.types
2022
bayesflow.utils
2123
bayesflow.workflows

0 commit comments

Comments
 (0)