Skip to content

Commit 8268128

Browse files
committed
point estimate of covariance compatible with standardization
1 parent dd24941 commit 8268128

File tree

4 files changed

+36
-28
lines changed

4 files changed

+36
-28
lines changed

bayesflow/approximators/point_approximator.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,15 @@ def estimate(
6161
conditions = {k: v for k, v in conditions.items() if k in ContinuousApproximator.CONDITION_KEYS}
6262

6363
estimates = self._estimate(**conditions, **kwargs)
64+
65+
if "inference_variables" in self.standardize:
66+
for score_key, score in self.inference_network.scores.items():
67+
for head_key in estimates[score_key].keys():
68+
trafo_type = score.TRANSFORMATION_TYPE.get(head_key, "rank1+shift")
69+
estimates[score_key][head_key] = self.standardize_layers["inference_variables"](
70+
estimates[score_key][head_key], forward=False, transformation_type=trafo_type
71+
)
72+
6473
estimates = self._apply_inverse_adapter_to_estimates(estimates, **kwargs)
6574

6675
# Optionally split the arrays along the last axis.

bayesflow/networks/standardization/standardization.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def call(
5353
stage: str = "inference",
5454
forward: bool = True,
5555
log_det_jac: bool = False,
56+
transformation_type: str = "rank1+shift",
5657
**kwargs,
5758
) -> Tensor | Sequence[Tensor]:
5859
"""
@@ -75,6 +76,14 @@ def call(
7576
Tensor or Sequence[Tensor]
7677
Transformed tensor, and optionally the log-determinant if `log_det_jac=True`.
7778
"""
79+
msg = """
80+
Non-default transformation (i.e. transformation_type != "rank1+shift")
81+
is not supported for forward or log_det_jac.
82+
"""
83+
if forward or log_det_jac:
84+
if transformation_type != "rank1+shift": # non default transformation
85+
raise ValueError(msg)
86+
7887
flattened = keras.tree.flatten(x)
7988
outputs, log_det_jacs = [], []
8089

@@ -91,7 +100,13 @@ def call(
91100
# we can just replace them with zeros.
92101
out = keras.ops.nan_to_num(out, nan=0.0)
93102
else:
94-
out = mean + std * val
103+
match transformation_type:
104+
case "rank1+shift":
105+
# x_i = x_i' * std + mean
106+
out = val * std + mean
107+
case "rank02":
108+
# x_ij = x_ij * sigma_i * sigma_j
109+
out = val * std * keras.ops.moveaxis(std, -1, -2)
95110

96111
outputs.append(out)
97112

bayesflow/scores/multivariate_normal_score.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,16 +26,16 @@ class MultivariateNormalScore(ParametricDistributionScore):
2626
For more information see :py:class:`ScoringRule`.
2727
"""
2828

29-
RANK: dict[str, int] = {"covariance": 2}
29+
TRANSFORMATION_TYPE: dict[str, str] = {"covariance": "rank02"}
3030
"""
31-
The covariance matrix is a rank 2 tensor and as such the inverse of the standardization operation is
31+
Marks covariance head to handle de-standardization as for covariant rank-(0,2) tensors.
3232
33-
x = x' * sigma ^ 2
33+
The appropriate inverse of the standardization operation is
3434
35-
Accordingly, covariance is also included in :py:attr:`NO_SHIFT`.
36-
"""
35+
x_ij = x_ij * sigma_i * sigma_j.
3736
38-
NO_SHIFT: tuple[str] = ("covariance",)
37+
For the mean head the default ("rank1+shift") is not overridden.
38+
"""
3939

4040
def __init__(self, dim: int = None, links: dict = None, **kwargs):
4141
super().__init__(links=links, **kwargs)

bayesflow/scores/scoring_rule.py

Lines changed: 5 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -37,31 +37,15 @@ class ScoringRule:
3737
with a type of estimate whenever the adapter is applied to them in inverse direction.
3838
"""
3939

40-
RANK: dict[str, int] = {}
40+
TRANSFORMATION_TYPE: dict[str, str] = {"covariance": "rank02"}
4141
"""
42-
Mapping of prediction head names to their tensor rank for inverse standardization.
42+
Defines nonstandard transformation behaviour for de-standardization.
4343
44-
The rank indicates the power to which the standard deviation is raised before being multiplied to some estimate
45-
in standardized space.
44+
The standard transformation
4645
47-
x = x' * sigma ^ rank [ + mean ]
46+
x_i = x_i' * std + mean
4847
49-
If a head is not present in this mapping, a default rank of 1 is assumed.
50-
51-
Typically, if :py:attr:`RANK` is modified for an estimate, it is also included in :py:attr:`NO_SHIFT`.
52-
"""
53-
54-
NO_SHIFT: tuple[str] = tuple()
55-
"""
56-
Names of prediction heads whose estimates should not be shifted when applying inverse standardization.
57-
58-
During inverse standardization, point estimates are typically shifted by the stored mean vector. Any head
59-
listed in this tuple will skip the shift step and only be scaled. By default, this tuple is empty,
60-
meaning all heads will be shifted to undo standardization.
61-
62-
x = x' * sigma ^ rank + mean
63-
64-
See also :py:attr:`RANK`.
48+
is referred to as "rank1+shift". Keys not specified here will fallback to that default.
6549
"""
6650

6751
def __init__(

0 commit comments

Comments
 (0)