Skip to content

Commit 392d9f7

Browse files
committed
Add class attributes to inform proper standardization
1 parent b2bfeea commit 392d9f7

File tree

2 files changed

+40
-3
lines changed

2 files changed

+40
-3
lines changed

bayesflow/scores/multivariate_normal_score.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,17 @@ class MultivariateNormalScore(ParametricDistributionScore):
2626
For more information see :py:class:`ScoringRule`.
2727
"""
2828

29+
RANK: dict[str, int] = {"covariance": 2}
30+
"""
31+
The covariance matrix is a rank 2 tensor and as such the inverse of the standardization operation is
32+
33+
x = x' * sigma ^ 2
34+
35+
Accordingly, covariance is also included in :py:attr:`NO_SHIFT`.
36+
"""
37+
38+
NO_SHIFT: tuple[str] = ("covariance",)
39+
2940
def __init__(self, dim: int = None, links: dict = None, **kwargs):
3041
super().__init__(links=links, **kwargs)
3142

bayesflow/scores/scoring_rule.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,9 @@ class ScoringRule:
2626
and covariance simultaneously.
2727
"""
2828

29-
NOT_TRANSFORMING_LIKE_VECTOR_WARNING = tuple()
29+
NOT_TRANSFORMING_LIKE_VECTOR_WARNING: tuple[str] = tuple()
3030
"""
31-
This variable contains names of prediction heads that should lead to a warning when the adapter is applied
32-
in inverse direction to them.
31+
Names of prediction heads for which to warn if the adapter is called on their estimates in inverse direction.
3332
3433
Prediction heads can output estimates in spaces other than the target distribution space.
3534
To such estimates the adapter cannot be straightforwardly applied in inverse direction,
@@ -38,6 +37,33 @@ class ScoringRule:
3837
with a type of estimate whenever the adapter is applied to them in inverse direction.
3938
"""
4039

40+
RANK: dict[str, int] = {}
41+
"""
42+
Mapping of prediction head names to their tensor rank for inverse standardization.
43+
44+
The rank indicates the power to which the standard deviation is raised before being multiplied to some estimate
45+
in standardized space.
46+
47+
x = x' * sigma ^ rank [ + mean ]
48+
49+
If a head is not present in this mapping, a default rank of 1 is assumed.
50+
51+
Typically, if :py:attr:`RANK` is modified for an estimate, it is also included in :py:attr:`NO_SHIFT`.
52+
"""
53+
54+
NO_SHIFT: tuple[str] = tuple()
55+
"""
56+
Names of prediction heads whose estimates should not be shifted when applying inverse standardization.
57+
58+
During inverse standardization, point estimates are typically shifted by the stored mean vector. Any head
59+
listed in this tuple will skip the shift step and only be scaled. By default, this tuple is empty,
60+
meaning all heads will be shifted to undo standardization.
61+
62+
x = x' * sigma ^ rank + mean
63+
64+
See also :py:attr:`RANK`.
65+
"""
66+
4167
def __init__(
4268
self,
4369
subnets: dict[str, str | type] = None,

0 commit comments

Comments
 (0)