Skip to content

Commit 34c7f2a

Browse files
committed
Better parameterization of covariance matrices
1 parent 69e2387 commit 34c7f2a

File tree

12 files changed

+192
-61
lines changed

12 files changed

+192
-61
lines changed

bayesflow/approximators/point_approximator.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
)
66

77
from bayesflow.types import Tensor
8-
from bayesflow.utils import filter_kwargs, split_arrays, squeeze_inner_estimates_dict
8+
from bayesflow.utils import filter_kwargs, split_arrays, squeeze_inner_estimates_dict, logging
99
from .continuous_approximator import ContinuousApproximator
1010

1111

@@ -119,6 +119,7 @@ def sample(
119119
def _prepare_conditions(self, conditions: dict[str, np.ndarray], **kwargs) -> dict[str, Tensor]:
120120
"""Adapts and converts the conditions to tensors."""
121121
conditions = self.adapter(conditions, strict=False, stage="inference", **kwargs)
122+
conditions.pop("inference_variables", None)
122123
return keras.tree.map_structure(keras.ops.convert_to_tensor, conditions)
123124

124125
def _apply_inverse_adapter_to_estimates(
@@ -130,6 +131,12 @@ def _apply_inverse_adapter_to_estimates(
130131
for score_key, score_val in estimates.items():
131132
processed[score_key] = {}
132133
for head_key, estimate in score_val.items():
134+
if head_key in self.inference_network.scores[score_key].not_transforming_like_vector:
135+
logging.warning(
136+
f"Estimate '{score_key}.{head_key}' is marked to not transform like a vector. "
137+
"It was treated like a vector by the adapter. Handle '{head_key}' estimates with care."
138+
)
139+
133140
adapted = self.adapter(
134141
{"inference_variables": estimate},
135142
inverse=True,

bayesflow/links/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from .ordered import Ordered
44
from .ordered_quantiles import OrderedQuantiles
5-
from .positive_semi_definite import PositiveSemiDefinite
5+
from .positive_definite import PositiveDefinite
66

77
from ..utils._docs import _add_imports_to_all
88

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
import keras
2+
3+
# import numpy as np
4+
from keras.saving import register_keras_serializable as serializable
5+
6+
from bayesflow.types import Tensor
7+
from bayesflow.utils import keras_kwargs, fill_triangular_matrix
8+
9+
10+
@serializable(package="bayesflow.links")
11+
class PositiveDefinite(keras.Layer):
12+
"""Activation function to link from flat elements of a lower triangular matrix to a positive definite matrix."""
13+
14+
def __init__(self, **kwargs):
15+
super().__init__(**keras_kwargs(kwargs))
16+
self.built = True
17+
18+
def call(self, inputs: Tensor) -> Tensor:
19+
# Build cholesky factor from inputs
20+
L = fill_triangular_matrix(inputs, positive_diag=True)
21+
22+
# diagonal_mask = keras.ops.identity(L.shape[-1]) > 0
23+
# L[..., diagonal_mask] = keras.activations.softplus(L[..., diagonal_mask])
24+
# L += keras.ops.identity(L.shape[-1]) * 2
25+
# L *= keras.ops.sign(keras.ops.diagonal(L, axis1=-1))[..., None] # ensure positive diagonal entries
26+
27+
# calculate positive definite matrix from cholesky factors
28+
psd = keras.ops.matmul(
29+
L,
30+
keras.ops.moveaxis(L, -2, -1), # L transposed
31+
)
32+
return psd
33+
34+
def compute_output_shape(self, input_shape):
35+
m = input_shape[-1]
36+
n = int((0.25 + 2.0 * m) ** 0.5 - 0.5)
37+
return input_shape[:-1] + (n, n)
38+
39+
def compute_input_shape(self, output_shape):
40+
"""
41+
Returns the shape of the flat parameterization of a Cholesky factor (a triangular matrix).
42+
43+
There are m nonzero elements of a lower triangular nxn matrix with m = n * (n + 1) / 2.
44+
45+
Example
46+
-------
47+
>>> PositiveDefinite().compute_input_shape((None, 3, 3))
48+
(None, 6)
49+
"""
50+
n = output_shape[-1]
51+
m = int(n * (n + 1) / 2)
52+
return output_shape[:-2] + (m,)

bayesflow/links/positive_semi_definite.py

Lines changed: 0 additions & 20 deletions
This file was deleted.

bayesflow/scores/multivariate_normal_score.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from keras.saving import register_keras_serializable as serializable
55

66
from bayesflow.types import Shape, Tensor
7-
from bayesflow.links import PositiveSemiDefinite
7+
from bayesflow.links import PositiveDefinite
88
from bayesflow.utils import logging
99

1010
from .parametric_distribution_score import ParametricDistributionScore
@@ -21,7 +21,11 @@ def __init__(self, dim: int = None, links: dict = None, **kwargs):
2121
super().__init__(links=links, **kwargs)
2222

2323
self.dim = dim
24-
self.links = links or {"covariance": PositiveSemiDefinite()}
24+
self.links = links or {"covariance": PositiveDefinite()}
25+
26+
# mark head for covariance matrix as an exception for adapter transformations
27+
self.not_transforming_like_vector = ["covariance"]
28+
2529
self.config = {"dim": dim}
2630

2731
logging.warning("MultivariateNormalScore is unstable.")
@@ -60,12 +64,12 @@ def log_prob(self, x: Tensor, mean: Tensor, covariance: Tensor) -> Tensor:
6064
A tensor containing the log probability densities for each sample in `x` under the
6165
given Gaussian distribution.
6266
"""
63-
diff = x[:, None, :] - mean
64-
inv_covariance = keras.ops.inv(covariance)
67+
diff = x - mean
68+
precision = keras.ops.inv(covariance)
6569
log_det_covariance = keras.ops.slogdet(covariance)[1] # Only take the log of the determinant part
6670

6771
# Compute the quadratic term in the exponential of the multivariate Gaussian
68-
quadratic_term = keras.ops.einsum("...i,...ij,...j->...", diff, inv_covariance, diff)
72+
quadratic_term = keras.ops.einsum("...i,...ij,...j->...", diff, precision, diff)
6973

7074
# Compute the log probability density
7175
log_prob = -0.5 * (self.dim * keras.ops.log(2 * math.pi) + log_det_covariance + quadratic_term)

bayesflow/scores/parametric_distribution_score.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,5 +51,4 @@ def score(self, estimates: dict[str, Tensor], targets: Tensor, weights: Tensor =
5151
"""
5252
scores = -self.log_prob(x=targets, **estimates)
5353
score = self.aggregate(scores, weights)
54-
# multipy to mitigate instability due to relatively high values of parametric score
55-
return score * 0.01
54+
return score

bayesflow/scores/scoring_rule.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ def __init__(
2929
self.subnets_kwargs = subnets_kwargs or {}
3030
self.links = links or {}
3131

32+
self.not_transforming_like_vector = []
33+
3234
self.config = {"subnets_kwargs": self.subnets_kwargs}
3335

3436
def get_config(self):
@@ -95,14 +97,14 @@ def get_link(self, key: str) -> keras.Layer:
9597
else:
9698
return self.links[key]
9799

98-
def get_head(self, key: str, shape: Shape) -> keras.Sequential:
100+
def get_head(self, key: str, output_shape: Shape) -> keras.Sequential:
99101
"""For a specified head key and shape, request corresponding head network.
100102
101103
Parameters
102104
----------
103105
key : str
104106
Name of the head for which to request a head network.
105-
shape: Shape
107+
output_shape: Shape
106108
The necessary shape for the point estimators.
107109
108110
Returns
@@ -111,10 +113,19 @@ def get_head(self, key: str, shape: Shape) -> keras.Sequential:
111113
Head network consisting of a learnable projection, a reshape and a link operation
112114
to parameterize estimates.
113115
"""
114-
subnet = self.get_subnet(key)
115-
dense = keras.layers.Dense(units=math.prod(shape))
116-
reshape = keras.layers.Reshape(target_shape=shape)
116+
# initialize head components back to front
117117
link = self.get_link(key)
118+
119+
# link input shape can differ from output shape
120+
if hasattr(link, "compute_input_shape"):
121+
link_input_shape = link.compute_input_shape(output_shape)
122+
else:
123+
link_input_shape = output_shape
124+
125+
reshape = keras.layers.Reshape(target_shape=link_input_shape)
126+
dense = keras.layers.Dense(units=math.prod(link_input_shape))
127+
subnet = self.get_subnet(key)
128+
118129
return keras.Sequential([subnet, dense, reshape, link])
119130

120131
def score(self, estimates: dict[str, Tensor], targets: Tensor, weights: Tensor) -> Tensor:

bayesflow/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@
6666
tile_axis,
6767
tree_concatenate,
6868
tree_stack,
69+
fill_triangular_matrix,
6970
)
7071
from .validators import check_lengths_same
7172
from .workflow_utils import find_inference_network, find_summary_network

bayesflow/utils/tensor_utils.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,3 +277,80 @@ def stack(*items):
277277
return keras.ops.stack(items, axis=axis)
278278

279279
return keras.tree.map_structure(stack, *structures)
280+
281+
282+
def fill_triangular_matrix(x: Tensor, upper: bool = False, positive_diag: bool = False):
283+
"""
284+
Reshapes a batch of matrix entries into a triangular matrix (either upper or lower).
285+
286+
Note: If final axis has length 1, this simply reshapes to (batch_size, 1, 1) and optionally applies softplus.
287+
288+
Parameters
289+
----------
290+
x : Tensor of shape (batch_size, m)
291+
Batch of flattened nonzero matrix elements for triangular matrix.
292+
upper : bool
293+
Return upper triangular matrix if True, else lower triangular matrix. Default is False.
294+
positive_diag : bool
295+
Whether to apply a softplus operation to diagonal elements. Default is False.
296+
297+
Returns
298+
-------
299+
Tensor of shape (batch_size, n, n)
300+
Batch of triangular matrices with m = n * (n + 1) / 2 unique nonzero elements.
301+
302+
Raises
303+
------
304+
ValueError
305+
If the number of provided elements m does not correspond to a valid triangular matrix shape
306+
(n, n), i.e., n = sqrt(1/4 + 2 * m) - 1/2 is not an integer (recall m = n * (n + 1) / 2).
307+
"""
308+
batch_shape = x.shape[:-1]
309+
m = x.shape[-1]
310+
311+
if m == 1:
312+
y = keras.ops.reshape(x, (-1, 1, 1))
313+
if positive_diag:
314+
y = keras.activations.softplus(y)
315+
return y
316+
317+
# Calculate matrix shape
318+
n = (0.25 + 2 * m) ** 0.5 - 0.5
319+
if not np.isclose(np.floor(n), n):
320+
raise ValueError(f"Input right-most shape ({m}) does not correspond to a triangular matrix.")
321+
else:
322+
n = int(n)
323+
324+
# Trick: Create triangular matrix by concatenating with a flipped version of its tail, then reshape.
325+
x_tail = keras.ops.take(x, indices=list(range((m - (n**2 - m)), x.shape[-1])), axis=-1)
326+
if not upper:
327+
y = keras.ops.concatenate([x_tail, keras.ops.flip(x, axis=-1)], axis=len(batch_shape))
328+
y = keras.ops.reshape(y, (-1, n, n))
329+
y = keras.ops.tril(y) # TODO: fails with tensorflow
330+
331+
if positive_diag:
332+
y_offdiag = keras.ops.tril(y, k=-1)
333+
y_diag = keras.ops.tril(
334+
keras.ops.triu( # carve out diagonal, by setting upper and lower offdiagonals to zero
335+
keras.activations.softplus(y)
336+
), # apply softplus to enforce positivity
337+
)
338+
y = y_diag + y_offdiag
339+
340+
else:
341+
y = keras.ops.concatenate([x, keras.ops.flip(x_tail, axis=-1)], axis=len(batch_shape))
342+
y = keras.ops.reshape(y, (-1, n, n))
343+
y = keras.ops.triu(
344+
y,
345+
)
346+
347+
if positive_diag:
348+
y_offdiag = keras.ops.triu(y, k=1)
349+
y_diag = keras.ops.tril(
350+
keras.ops.triu( # carve out diagonal, by setting upper and lower offdiagonals to zero
351+
keras.activations.softplus(y)
352+
), # apply softplus to enforce positivity
353+
)
354+
y = y_diag + y_offdiag
355+
356+
return y

tests/test_links/conftest.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def num_variables():
1515

1616
@pytest.fixture()
1717
def generic_preactivation(batch_size):
18-
return keras.ops.ones((batch_size, 4, 4))
18+
return keras.ops.ones((batch_size, 6))
1919

2020

2121
@pytest.fixture()
@@ -33,18 +33,18 @@ def ordered_quantiles():
3333

3434

3535
@pytest.fixture()
36-
def positive_semi_definite():
37-
from bayesflow.links import PositiveSemiDefinite
36+
def positive_definite():
37+
from bayesflow.links import PositiveDefinite
3838

39-
return PositiveSemiDefinite()
39+
return PositiveDefinite()
4040

4141

4242
@pytest.fixture()
4343
def linear():
4444
return keras.layers.Activation("linear")
4545

4646

47-
@pytest.fixture(params=["ordered", "ordered_quantiles", "positive_semi_definite", "linear"], scope="function")
47+
@pytest.fixture(params=["ordered", "ordered_quantiles", "positive_definite", "linear"], scope="function")
4848
def link(request):
4949
return request.getfixturevalue(request.param)
5050

@@ -84,6 +84,6 @@ def unordered(batch_size, num_quantiles, num_variables):
8484
return keras.random.normal((batch_size, num_quantiles, num_variables))
8585

8686

87-
@pytest.fixture()
88-
def random_matrix_batch(batch_size, num_variables):
89-
return keras.random.normal((batch_size, num_variables, num_variables))
87+
# @pytest.fixture()
88+
# def random_matrix_batch(batch_size, num_variables):
89+
# return keras.random.normal((batch_size, num_variables, num_variables))

0 commit comments

Comments
 (0)