Commit f21e6ef

Stabilizing multivariate normal approximation (#380)
* Better parameterization of covariance matrices
* Fix format string
* Test for invertibility of positive definite link output
* Allow estimation of univariate MVN
* Remove commented lines
* Minor changes to comments and docstring for fill_triangular_matrix
* Test coverage for unconditional MVNScore.sample
* Remove instability warning from MultivariateNormalScore
* Remove commented numpy import
* Fix dtype of dummy conditions if inference variables are available
* Tuple conversion in case batch_shape is a list
* Conversion to numpy before calling numpy operations
* More detailed docs and renamed the transformation warning attribute
* Docstring detail
* Remove untested comment for PointInferenceNetwork.sample()
* Relax type hints for ContinuousApproximator.log_prob
* Support log-prob in PointApproximator
* Remove comment stating log prob was untested
* Fix typo
* Transformation warning using a class variable; docstring links
1 parent 1474bb6 · commit f21e6ef

File tree

14 files changed (+302, -84 lines)

bayesflow/approximators/continuous_approximator.py

Lines changed: 3 additions & 3 deletions
@@ -338,7 +338,7 @@ def _sample(
             **filter_kwargs(kwargs, self.inference_network.sample),
         )

-    def log_prob(self, data: dict[str, np.ndarray], **kwargs) -> np.ndarray:
+    def log_prob(self, data: dict[str, np.ndarray], **kwargs) -> np.ndarray | dict[str, np.ndarray]:
         """
         Computes the log-probability of given data under the model. The `data` dictionary is preprocessed using the
         `adapter`. Log-probabilities are returned as NumPy arrays.
@@ -358,7 +358,7 @@ def log_prob(self, data: dict[str, np.ndarray], **kwargs) -> np.ndarray:
         data = self.adapter(data, strict=False, stage="inference", **kwargs)
         data = keras.tree.map_structure(keras.ops.convert_to_tensor, data)
         log_prob = self._log_prob(**data, **kwargs)
-        log_prob = keras.ops.convert_to_numpy(log_prob)
+        log_prob = keras.tree.map_structure(keras.ops.convert_to_numpy, log_prob)

         return log_prob

@@ -368,7 +368,7 @@ def _log_prob(
         inference_conditions: Tensor = None,
         summary_variables: Tensor = None,
         **kwargs,
-    ) -> Tensor:
+    ) -> Tensor | dict[str, Tensor]:
         if self.summary_network is None:
             if summary_variables is not None:
                 raise ValueError("Cannot use summary variables without a summary network.")

bayesflow/approximators/point_approximator.py

Lines changed: 53 additions & 8 deletions
@@ -5,7 +5,7 @@
 )

 from bayesflow.types import Tensor
-from bayesflow.utils import filter_kwargs, split_arrays, squeeze_inner_estimates_dict
+from bayesflow.utils import filter_kwargs, split_arrays, squeeze_inner_estimates_dict, logging
 from .continuous_approximator import ContinuousApproximator


@@ -14,8 +14,9 @@ class PointApproximator(ContinuousApproximator):
     """
     A workflow for fast amortized point estimation of a conditional distribution.

-    The distribution is approximated by point estimators, parameterized by a feed-forward `PointInferenceNetwork`.
-    Conditions can be compressed by an optional `SummaryNetwork` or used directly as input to the inference network.
+    The distribution is approximated by point estimators, parameterized by a feed-forward
+    :class:`bayesflow.networks.PointInferenceNetwork`. Conditions can be compressed by an optional summary network
+    (inheriting from :class:`bayesflow.networks.SummaryNetwork`) or used directly as input to the inference network.
     """

     def estimate(
@@ -89,7 +90,7 @@ def sample(
             for the sampling process.
         split : bool, optional
             If True, the sampled arrays are split along the last axis, by default False.
-            Currently not supported for `PointApproximator`.
+            Currently not supported for :class:`PointApproximator` .
         **kwargs
             Additional keyword arguments passed to underlying processing functions.
@@ -111,14 +112,50 @@ def sample(
         if split:
             raise NotImplementedError("split=True is currently not supported for `PointApproximator`.")
             samples = split_arrays(samples, axis=-1)
-        # Squeeze samples if there's only one key-value pair.
-        samples = self._squeeze_samples(samples)
+        # Squeeze sample dictionary if there's only one key-value pair.
+        samples = self._squeeze_parametric_score_major_dict(samples)

         return samples

+    def log_prob(
+        self,
+        *,
+        data: dict[str, np.ndarray],
+        **kwargs,
+    ) -> np.ndarray | dict[str, np.ndarray]:
+        """
+        Computes the log-probability of given data under the parametric distribution(s) for given input conditions.
+
+        Parameters
+        ----------
+        data : dict[str, np.ndarray]
+            A dictionary mapping variable names to arrays representing the inference conditions and variables.
+        **kwargs
+            Additional keyword arguments passed to underlying processing functions.
+
+        Returns
+        -------
+        log_prob : np.ndarray or dict[str, np.ndarray]
+            Log-probabilities of the distribution
+            `p(inference_variables | inference_conditions, h(summary_conditions))` for all parametric scoring rules.
+
+            If only one parametric score is available, output is an array of log-probabilities.
+
+            Output is a dictionary if multiple parametric scores are available.
+            Then, each key is the name of a score and values are corresponding log-probabilities.
+
+            Log-probabilities have shape (num_datasets,).
+        """
+        log_prob = super().log_prob(data=data, **kwargs)
+        # Squeeze log-probabilities dictionary if there's only one key-value pair.
+        log_prob = self._squeeze_parametric_score_major_dict(log_prob)
+
+        return log_prob
+
     def _prepare_conditions(self, conditions: dict[str, np.ndarray], **kwargs) -> dict[str, Tensor]:
         """Adapts and converts the conditions to tensors."""
         conditions = self.adapter(conditions, strict=False, stage="inference", **kwargs)
+        conditions.pop("inference_variables", None)
         return keras.tree.map_structure(keras.ops.convert_to_tensor, conditions)

     def _apply_inverse_adapter_to_estimates(
@@ -130,6 +167,12 @@ def _apply_inverse_adapter_to_estimates(
         for score_key, score_val in estimates.items():
             processed[score_key] = {}
             for head_key, estimate in score_val.items():
+                if head_key in self.inference_network.scores[score_key].NOT_TRANSFORMING_LIKE_VECTOR_WARNING:
+                    logging.warning(
+                        f"Estimate '{score_key}.{head_key}' is marked to not transform like a vector. "
+                        f"It was treated like a vector by the adapter. Handle '{head_key}' estimates with care."
+                    )
+
                 adapted = self.adapter(
                     {"inference_variables": estimate},
                     inverse=True,
@@ -180,8 +223,10 @@ def _squeeze_estimates(
         }
         return squeezed

-    def _squeeze_samples(self, samples: dict[str, np.ndarray]) -> np.ndarray or dict[str, np.ndarray]:
-        """Squeezes the samples dictionary to just the value if there is only one key-value pair."""
+    def _squeeze_parametric_score_major_dict(
+        self, samples: dict[str, np.ndarray]
+    ) -> np.ndarray or dict[str, np.ndarray]:
+        """Squeezes the dictionary to just the value if there is only one key-value pair."""
         if len(samples) == 1:
             return next(iter(samples.values()))  # Extract and return the only item's value
         return samples
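The renamed private helper `_squeeze_parametric_score_major_dict` collapses the score-major dict whenever only one parametric score is configured. A runnable sketch mirroring that behavior (the function below is a standalone stand-in, not the bayesflow implementation):

import numpy as np

def squeeze_parametric_score_major_dict(d):
    # Collapse a one-entry {score_name: array} dict to its bare value,
    # mirroring PointApproximator._squeeze_parametric_score_major_dict.
    if len(d) == 1:
        return next(iter(d.values()))
    return d

one_score = {"MultivariateNormalScore": np.array([-3.1, -2.8])}
two_scores = {"mvn": np.array([-3.1, -2.8]), "other": np.array([-4.0, -3.9])}

print(squeeze_parametric_score_major_dict(one_score))   # bare array of shape (num_datasets,)
print(squeeze_parametric_score_major_dict(two_scores))  # unchanged dict keyed by score name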

bayesflow/links/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@

 from .ordered import Ordered
 from .ordered_quantiles import OrderedQuantiles
-from .positive_semi_definite import PositiveSemiDefinite
+from .positive_definite import PositiveDefinite

 from ..utils._docs import _add_imports_to_all

bayesflow/links/positive_definite.py

Lines changed: 46 additions & 0 deletions

@@ -0,0 +1,46 @@
+import keras
+
+from keras.saving import register_keras_serializable as serializable
+
+from bayesflow.types import Tensor
+from bayesflow.utils import keras_kwargs, fill_triangular_matrix
+
+
+@serializable(package="bayesflow.links")
+class PositiveDefinite(keras.Layer):
+    """Activation function to link from flat elements of a lower triangular matrix to a positive definite matrix."""
+
+    def __init__(self, **kwargs):
+        super().__init__(**keras_kwargs(kwargs))
+        self.built = True
+
+    def call(self, inputs: Tensor) -> Tensor:
+        # Build Cholesky factor from inputs
+        L = fill_triangular_matrix(inputs, positive_diag=True)
+
+        # Calculate positive definite matrix from the Cholesky factor
+        psd = keras.ops.matmul(
+            L,
+            keras.ops.moveaxis(L, -2, -1),  # L transposed
+        )
+        return psd
+
+    def compute_output_shape(self, input_shape):
+        m = input_shape[-1]
+        n = int((0.25 + 2.0 * m) ** 0.5 - 0.5)
+        return input_shape[:-1] + (n, n)
+
+    def compute_input_shape(self, output_shape):
+        """
+        Returns the shape of the parameterization of a Cholesky factor triangular matrix.
+
+        There are m nonzero elements of a lower triangular n x n matrix, with m = n * (n + 1) / 2.
+
+        Example
+        -------
+        >>> PositiveDefinite().compute_input_shape((None, 3, 3))
+        (None, 6)
+        """
+        n = output_shape[-1]
+        m = int(n * (n + 1) / 2)
+        return output_shape[:-2] + (m,)
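The shape methods invert m = n * (n + 1) / 2, the number of free entries of a lower-triangular n x n matrix, and `call` builds L @ L^T, the usual Cholesky parameterization: any lower-triangular factor with strictly positive diagonal yields a symmetric positive definite (hence invertible) product. A standalone numpy check of both facts, independent of `fill_triangular_matrix`:

import numpy as np

# Round trip of the triangular-count relation m = n * (n + 1) / 2.
for n in range(1, 8):
    m = n * (n + 1) // 2
    assert int((0.25 + 2.0 * m) ** 0.5 - 0.5) == n

# A lower-triangular factor with positive diagonal gives a positive definite L @ L.T.
rng = np.random.default_rng(0)
L = np.tril(rng.normal(size=(4, 4)))
np.fill_diagonal(L, np.exp(np.diag(L)))  # enforce a strictly positive diagonal
spd = L @ L.T
assert np.all(np.linalg.eigvalsh(spd) > 0)  # all eigenvalues positive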

bayesflow/links/positive_semi_definite.py

Lines changed: 0 additions & 20 deletions
This file was deleted.

bayesflow/networks/point_inference_network.py

Lines changed: 3 additions & 3 deletions
@@ -132,7 +132,9 @@ def call(
         if xz is None and not self.built:
             raise ValueError("Cannot build inference network without inference variables.")
         if conditions is None:  # unconditional estimation uses a fixed input vector
-            conditions = keras.ops.convert_to_tensor([[1.0]], dtype=keras.ops.dtype(xz))
+            conditions = keras.ops.convert_to_tensor(
+                [[1.0]], dtype=keras.ops.dtype(xz) if xz is not None else "float32"
+            )

         # pass conditions to the shared subnet
         output = self.subnet(conditions, training=training)
@@ -165,7 +167,6 @@ def compute_metrics(

         return metrics | {"loss": neg_score}

-    # WIP: untested draft of sample method
     @allow_batch_size
     def sample(self, batch_shape: Shape, conditions: Tensor = None) -> dict[str, Tensor]:
         """
@@ -199,7 +200,6 @@ def sample(self, batch_shape: Shape, conditions: Tensor = None) -> dict[str, Tensor]:

         return samples

-    # WIP: untested draft of log_prob method
     def log_prob(self, samples: Tensor, conditions: Tensor = None, **kwargs) -> dict[str, Tensor]:
         output = self.subnet(conditions)
         log_probs = {}
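The widened dtype logic covers the fully unconditional case, where `call` receives neither inference variables nor conditions (for example when sampling after the network has been built). A minimal sketch of just that branch, assuming only keras:

import keras

xz = None  # e.g. sampling from an unconditional model after build
dtype = keras.ops.dtype(xz) if xz is not None else "float32"
conditions = keras.ops.convert_to_tensor([[1.0]], dtype=dtype)
print(keras.ops.dtype(conditions))  # float32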

bayesflow/scores/multivariate_normal_score.py

Lines changed: 18 additions & 8 deletions
@@ -4,8 +4,7 @@
 from keras.saving import register_keras_serializable as serializable

 from bayesflow.types import Shape, Tensor
-from bayesflow.links import PositiveSemiDefinite
-from bayesflow.utils import logging
+from bayesflow.links import PositiveDefinite

 from .parametric_distribution_score import ParametricDistributionScore

@@ -17,14 +16,23 @@ class MultivariateNormalScore(ParametricDistributionScore):
     Scores a predicted mean and covariance matrix with the log-score of the probability of the materialized value.
     """

+    NOT_TRANSFORMING_LIKE_VECTOR_WARNING = ("covariance",)
+    """
+    Marks the head for the covariance matrix as an exception for adapter transformations.
+
+    This variable contains the names of prediction heads that should trigger a warning when the adapter
+    is applied to them in the inverse direction.
+
+    For more information see :class:`ScoringRule`.
+    """
+
     def __init__(self, dim: int = None, links: dict = None, **kwargs):
         super().__init__(links=links, **kwargs)

         self.dim = dim
-        self.links = links or {"covariance": PositiveSemiDefinite()}
-        self.config = {"dim": dim}
+        self.links = links or {"covariance": PositiveDefinite()}

-        logging.warning("MultivariateNormalScore is unstable.")
+        self.config = {"dim": dim}

     def get_config(self):
         base_config = super().get_config()
@@ -60,12 +68,12 @@ def log_prob(self, x: Tensor, mean: Tensor, covariance: Tensor) -> Tensor:
             A tensor containing the log probability densities for each sample in `x` under the
             given Gaussian distribution.
         """
-        diff = x[:, None, :] - mean
-        inv_covariance = keras.ops.inv(covariance)
+        diff = x - mean
+        precision = keras.ops.inv(covariance)
         log_det_covariance = keras.ops.slogdet(covariance)[1]  # Only take the log of the determinant part

         # Compute the quadratic term in the exponential of the multivariate Gaussian
-        quadratic_term = keras.ops.einsum("...i,...ij,...j->...", diff, inv_covariance, diff)
+        quadratic_term = keras.ops.einsum("...i,...ij,...j->...", diff, precision, diff)

         # Compute the log probability density
         log_prob = -0.5 * (self.dim * keras.ops.log(2 * math.pi) + log_det_covariance + quadratic_term)
@@ -97,6 +105,8 @@ def sample(self, batch_shape: Shape, mean: Tensor, covariance: Tensor) -> Tensor
         Tensor
             A tensor of shape (batch_size, num_samples, D) containing the generated samples.
         """
+        if len(batch_shape) == 1:
+            batch_shape = (1,) + tuple(batch_shape)
         batch_size, num_samples = batch_shape
         dim = keras.ops.shape(mean)[-1]
         if keras.ops.shape(mean) != (batch_size, dim):
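The corrected `log_prob` drops the stray broadcast `x[:, None, :]` and evaluates the standard multivariate normal density, log p(x) = -0.5 * (d * log(2 * pi) + log|Sigma| + (x - mu)^T Sigma^{-1} (x - mu)), with the quadratic term computed by the einsum. A standalone numpy cross-check of the same formula against scipy (assuming scipy is available):

import numpy as np
from scipy.stats import multivariate_normal

rng = np.random.default_rng(1)
dim = 3
mean = rng.normal(size=dim)
A = rng.normal(size=(dim, dim))
covariance = A @ A.T + dim * np.eye(dim)  # well-conditioned SPD covariance
x = rng.normal(size=dim)

diff = x - mean
quadratic_term = diff @ np.linalg.inv(covariance) @ diff
log_det_covariance = np.linalg.slogdet(covariance)[1]
log_prob = -0.5 * (dim * np.log(2 * np.pi) + log_det_covariance + quadratic_term)

assert np.isclose(log_prob, multivariate_normal(mean, covariance).logpdf(x))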

bayesflow/scores/parametric_distribution_score.py

Lines changed: 1 addition & 2 deletions
@@ -51,5 +51,4 @@ def score(self, estimates: dict[str, Tensor], targets: Tensor, weights: Tensor =
         """
         scores = -self.log_prob(x=targets, **estimates)
         score = self.aggregate(scores, weights)
-        # multipy to mitigate instability due to relatively high values of parametric score
-        return score * 0.01
+        return score
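Removing the ad-hoc `* 0.01` factor is consistent with the rest of the commit: once the covariance head is parameterized stably, a constant loss scale is unnecessary, since it only shrinks gradients and leaves the minimizer unchanged. A tiny standalone illustration:

import numpy as np

# Scaling a loss by a positive constant rescales its gradient but does not
# move the location of the minimum.
theta = np.linspace(-1.0, 5.0, 601)
loss = (theta - 2.0) ** 2

assert theta[np.argmin(loss)] == theta[np.argmin(0.01 * loss)]  # same argmin
print(np.gradient(loss, theta)[0], np.gradient(0.01 * loss, theta)[0])  # 100x smaller gradient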
