3 changes: 3 additions & 0 deletions bayesflow/networks/standardization/standardization.py
@@ -117,6 +117,9 @@ def call(
case "left_side_scale":
# x_ij = sigma_i * x_ij'
out = val * keras.ops.moveaxis(std, -1, -2)
case "right_side_scale_inverse":
# x_ij = x_ij' / sigma_j
out = val / std
case _:
out = val

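For intuition, here is a minimal numpy sketch of what the two one-sided matrix cases above compute, assuming `std` holds the per-dimension scales sigma (the names and shapes here are illustrative, not the layer's internals):

import numpy as np

std = np.array([2.0, 5.0])                # per-dimension scales sigma_i
val = np.array([[1.0, 0.0], [0.3, 1.2]])  # a lower-triangular prediction x'

# "left_side_scale": scale rows, x_ij = sigma_i * x'_ij
left = val * std[:, None]                 # [[2.0, 0.0], [1.5, 6.0]]

# "right_side_scale_inverse": divide columns, x_ij = x'_ij / sigma_j
right = val / std                         # [[0.5, 0.0], [0.15, 0.24]]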
66 changes: 33 additions & 33 deletions bayesflow/scores/multivariate_normal_score.py
@@ -13,27 +13,26 @@
class MultivariateNormalScore(ParametricDistributionScore):
r""":math:`S(\hat p_{\mu, \Sigma}, \theta; k) = -\log( \mathcal N (\theta; \mu, \Sigma))`

Scores a predicted mean and (Cholesky factor of the) covariance matrix with the log-score of the probability
of the materialized value.
Scores a predicted mean and lower-triangular Cholesky factor :math:`L` of the precision matrix :math:`P`
with the log-score of the probability of the materialized value. The precision matrix is
the inverse of the covariance matrix, :math:`L^T L = P = \Sigma^{-1}`.
"""

NOT_TRANSFORMING_LIKE_VECTOR_WARNING = ("cov_chol",)
NOT_TRANSFORMING_LIKE_VECTOR_WARNING = ("precision_cholesky_factor",)
"""
Marks head for covariance matrix Cholesky factor as an exception for adapter transformations.
Marks head for precision matrix Cholesky factor as an exception for adapter transformations.

This variable contains the names of prediction heads that should trigger a warning when the adapter is applied
to them in the inverse direction.

For more information see :py:class:`ScoringRule`.
"""

TRANSFORMATION_TYPE: dict[str, str] = {"cov_chol": "left_side_scale"}
TRANSFORMATION_TYPE: dict[str, str] = {"precision_cholesky_factor": "right_side_scale_inverse"}
"""
Marks covariance Cholesky factor head to handle de-standardization as for covariant rank-(0,2) tensors.
Marks the precision Cholesky factor head for de-standardization via the "right_side_scale_inverse" rule.

The appropriate inverse of the standardization operation is

x_ij = sigma_i * x_ij'.
See :py:class:`bayesflow.networks.Standardization` for more information on supported de-standardization options.

For the mean head, the default ("location_scale") is not overridden.
"""
@@ -42,7 +41,7 @@ def __init__(self, dim: int = None, links: dict = None, **kwargs):
super().__init__(links=links, **kwargs)

self.dim = dim
self.links = links or {"cov_chol": CholeskyFactor()}
self.links = links or {"precision_cholesky_factor": CholeskyFactor()}

self.config = {"dim": dim}

@@ -52,16 +51,16 @@ def get_config(self):

def get_head_shapes_from_target_shape(self, target_shape: Shape) -> dict[str, Shape]:
self.dim = target_shape[-1]
return dict(mean=(self.dim,), cov_chol=(self.dim, self.dim))
return dict(mean=(self.dim,), precision_cholesky_factor=(self.dim, self.dim))

def log_prob(self, x: Tensor, mean: Tensor, cov_chol: Tensor) -> Tensor:
def log_prob(self, x: Tensor, mean: Tensor, precision_cholesky_factor: Tensor) -> Tensor:
"""
Compute the log probability density of a multivariate Gaussian distribution.

This function calculates the log probability density for each sample in `x` under a
multivariate Gaussian distribution with the given `mean` and `cov_chol`.
multivariate Gaussian distribution with the given `mean` and `precision_cholesky_factor`.

The computation includes the determinant of the covariance matrix, its inverse, and the quadratic
The computation includes the log-determinant of the covariance matrix and the quadratic
form in the exponential term of the Gaussian density function.

Parameters
@@ -71,8 +70,9 @@ def log_prob(self, x: Tensor, mean: Tensor, cov_chol: Tensor) -> Tensor:
The shape should be compatible with broadcasting against `mean`.
mean : Tensor
A tensor representing the mean of the multivariate Gaussian distribution.
covariance : Tensor
A tensor representing the covariance matrix of the multivariate Gaussian distribution.
precision_cholesky_factor : Tensor
A tensor representing the lower-triangular Cholesky factor of the precision matrix
of the multivariate Gaussian distribution.

Returns
-------
@@ -82,29 +82,27 @@
"""
diff = x - mean

# Calculate precision from Cholesky factors of covariance matrix
cov_chol_inv = keras.ops.inv(cov_chol)
precision = keras.ops.matmul(
keras.ops.swapaxes(cov_chol_inv, -2, -1),
cov_chol_inv,
)

# Compute log determinant, exploiting Cholesky factors
log_det_covariance = keras.ops.log(keras.ops.prod(keras.ops.diagonal(cov_chol, axis1=1, axis2=2), axis=1)) * 2
log_det_covariance = -2 * keras.ops.sum(
keras.ops.log(keras.ops.diagonal(precision_cholesky_factor, axis1=1, axis2=2)), axis=1
)

# Compute the quadratic term in the exponential of the multivariate Gaussian
quadratic_term = keras.ops.einsum("...i,...ij,...j->...", diff, precision, diff)
# Compute the quadratic term in the exponential of the multivariate Gaussian from Cholesky factors
# diff^T * precision_cholesky_factor^T * precision_cholesky_factor * diff
quadratic_term = keras.ops.einsum(
"...i,...ji,...jk,...k->...", diff, precision_cholesky_factor, precision_cholesky_factor, diff
)

# Compute the log probability density
log_prob = -0.5 * (self.dim * keras.ops.log(2 * math.pi) + log_det_covariance + quadratic_term)

return log_prob
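As a standalone sanity check of the two Cholesky identities used above (a numpy-only sketch under the convention L.T @ L = P = inv(cov); not part of the test suite):

import numpy as np

rng = np.random.default_rng(0)
dim = 3
A = rng.normal(size=(dim, dim))
cov = A @ A.T + dim * np.eye(dim)           # a well-conditioned covariance
L = np.linalg.inv(np.linalg.cholesky(cov))  # lower-triangular, L.T @ L == inv(cov)

# log|cov| from the factor's diagonal alone
assert np.isclose(-2 * np.sum(np.log(np.diag(L))), np.linalg.slogdet(cov)[1])

# quadratic form diff^T inv(cov) diff without forming the precision matrix
diff = rng.normal(size=dim)
assert np.isclose(diff @ L.T @ L @ diff, diff @ np.linalg.solve(cov, diff))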

def sample(self, batch_shape: Shape, mean: Tensor, cov_chol: Tensor) -> Tensor:
def sample(self, batch_shape: Shape, mean: Tensor, precision_cholesky_factor: Tensor) -> Tensor:
"""
Generate samples from a multivariate Gaussian distribution.

Independent standard normal samples are transformed using the Cholesky factor of the covariance matrix
Independent standard normal samples are transformed using the inverse of the precision Cholesky factor
to generate correlated samples.

Parameters
@@ -114,32 +112,34 @@ def sample(self, batch_shape: Shape, mean: Tensor, cov_chol: Tensor) -> Tensor:
mean : Tensor
A tensor representing the mean of the multivariate Gaussian distribution.
Must have shape (batch_size, D), where D is the dimensionality of the distribution.
cov_chol : Tensor
A tensor representing a Cholesky factor of the covariance matrix of the multivariate Gaussian distribution.
precision_cholesky_factor : Tensor
A tensor representing the lower-triangular Cholesky factor of the precision matrix
of the multivariate Gaussian distribution.
Must have shape (batch_size, D, D), where D is the dimensionality.

Returns
-------
Tensor
A tensor of shape (batch_size, num_samples, D) containing the generated samples.
"""
covariance_cholesky_factor = keras.ops.inv(precision_cholesky_factor)
if len(batch_shape) == 1:
batch_shape = (1,) + tuple(batch_shape)
batch_size, num_samples = batch_shape
dim = keras.ops.shape(mean)[-1]
if keras.ops.shape(mean) != (batch_size, dim):
raise ValueError(f"mean must have shape (batch_size, {dim}), but got {keras.ops.shape(mean)}")

if keras.ops.shape(cov_chol) != (batch_size, dim, dim):
if keras.ops.shape(precision_cholesky_factor) != (batch_size, dim, dim):
raise ValueError(
f"covariance Cholesky factor must have shape (batch_size, {dim}, {dim}),"
f"but got {keras.ops.shape(cov_chol)}"
f"but got {keras.ops.shape(precision_cholesky_factor)}"
)

# Use Cholesky decomposition to generate samples
normal_samples = keras.random.normal((*batch_shape, dim))

scaled_normal = keras.ops.einsum("ijk,ilk->ilj", cov_chol, normal_samples)
scaled_normal = keras.ops.einsum("ijk,ilk->ilj", covariance_cholesky_factor, normal_samples)
samples = mean[:, None, :] + scaled_normal

return samples
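A quick numpy sketch of the sampling identity (illustrative, not the module's tests): with L.T @ L = P = inv(cov), the matrix C = inv(L) satisfies C @ C.T = cov, so C @ z has covariance cov for standard normal z:

import numpy as np

rng = np.random.default_rng(1)
cov = np.array([[1.0, 0.5], [0.5, 2.0]])
L = np.linalg.inv(np.linalg.cholesky(cov))  # precision Cholesky factor
C = np.linalg.inv(L)                        # back to the covariance factor

z = rng.normal(size=(100_000, 2))
samples = z @ C.T                           # same contraction as the einsum "ijk,ilk->ilj"
print(np.cov(samples, rowvar=False))        # close to cov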
2 changes: 1 addition & 1 deletion bayesflow/utils/tensor_utils.py
@@ -311,7 +311,7 @@ def stack(*items):
return keras.tree.map_structure(stack, *structures)


def fill_triangular_matrix(x: Tensor, upper: bool = False, positive_diag: bool = False):
def fill_triangular_matrix(x: Tensor, upper: bool = False):
"""
Reshapes a batch of matrix elements into a triangular matrix (either upper or lower).

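For reference, a usage sketch of the shape contract (assuming the function is importable as the file path suggests; the exact element packing order is documented in the source): a length-m vector with m = D*(D+1)/2 fills a D x D triangular matrix.

import keras
from bayesflow.utils.tensor_utils import fill_triangular_matrix

x = keras.ops.convert_to_tensor([[1.0, 2.0, 3.0]])  # 3 == 2 * (2 + 1) / 2 packed entries
tril = fill_triangular_matrix(x)                    # lower-triangular by default
print(keras.ops.shape(tril))                        # (1, 2, 2)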
4 changes: 0 additions & 4 deletions tests/test_approximators/test_fit.py
@@ -3,7 +3,6 @@
import pytest
import io
from contextlib import redirect_stdout
from tests.utils import check_approximator_multivariate_normal_score


@pytest.mark.skip(reason="not implemented")
@@ -20,9 +19,6 @@ def test_fit(amortizer, dataset):


def test_loss_progress(approximator, train_dataset, validation_dataset):
# as long as MultivariateNormalScore is unstable, skip fit progress test
check_approximator_multivariate_normal_score(approximator)

approximator.compile(optimizer="AdamW")
num_epochs = 3

4 changes: 1 addition & 3 deletions tests/test_approximators/test_log_prob.py
@@ -1,12 +1,10 @@
import keras
import numpy as np
from tests.utils import check_combination_simulator_adapter, check_approximator_multivariate_normal_score
from tests.utils import check_combination_simulator_adapter


def test_approximator_log_prob(approximator, simulator, batch_size, adapter):
check_combination_simulator_adapter(simulator, adapter)
# as long as MultivariateNormalScore is unstable, skip
check_approximator_multivariate_normal_score(approximator)

num_batches = 4
data = simulator.sample((num_batches * batch_size,))
4 changes: 1 addition & 3 deletions tests/test_approximators/test_sample.py
@@ -1,11 +1,9 @@
import keras
from tests.utils import check_combination_simulator_adapter, check_approximator_multivariate_normal_score
from tests.utils import check_combination_simulator_adapter


def test_approximator_sample(approximator, simulator, batch_size, adapter):
check_combination_simulator_adapter(simulator, adapter)
# as long as MultivariateNormalScore is unstable, skip
check_approximator_multivariate_normal_score(approximator)

num_batches = 4
data = simulator.sample((num_batches * batch_size,))
16 changes: 14 additions & 2 deletions tests/test_networks/test_standardization.py
@@ -1,3 +1,4 @@
import pytest
import numpy as np
import keras

@@ -156,9 +157,11 @@ def test_transformation_type_both_sides_scale():
np.testing.assert_allclose(cov_input, cov_standardized_and_recovered, atol=1e-4)


def test_transformation_type_left_side_scale():
@pytest.mark.parametrize("transformation_type", ["left_side_scale", "right_side_scale_inverse"])
def test_transformation_type_one_side_scale(transformation_type):
# Fix a known covariance and mean in the original (not standardized) space
covariance = np.array([[1, 0.5], [0.5, 2.0]], dtype="float32")

mean = np.array([1, 10], dtype="float32")

# Generate samples
@@ -177,16 +180,25 @@ def test_transformation_type_left_side_scale():
cov_standardized = np.cov(keras.ops.convert_to_numpy(standardized), rowvar=False)
cov_standardized = keras.ops.convert_to_tensor(cov_standardized)
chol_standardized = keras.ops.cholesky(cov_standardized) # (dim, dim)

# We test the right_side_scale_inverse transformation by back-transforming a precision Cholesky factor
# instead of a covariance Cholesky factor.
if "inverse" in transformation_type:
chol_standardized = keras.ops.inv(chol_standardized)

# Inverse standardization of the Cholesky factor from standardized space
chol_standardized_and_recovered = layer(
chol_standardized, stage="inference", forward=False, transformation_type="left_side_scale"
chol_standardized, stage="inference", forward=False, transformation_type=transformation_type
)

random_input = keras.ops.convert_to_numpy(random_input)
chol_standardized_and_recovered = keras.ops.convert_to_numpy(chol_standardized_and_recovered)
cov_input = np.cov(random_input, rowvar=False)
chol_input = np.linalg.cholesky(cov_input)

if "inverse" in transformation_type:
chol_input = np.linalg.inv(chol_input)

np.testing.assert_allclose(chol_input, chol_standardized_and_recovered, atol=1e-4)


10 changes: 0 additions & 10 deletions tests/utils/check_combinations.py
@@ -19,13 +19,3 @@ def check_combination_simulator_adapter(simulator, adapter):
# to be used as sample weight, no error is raised currently.
# Don't use this fixture combination for further tests.
pytest.skip(reason="Do not use this fixture combination for further tests") # TODO: better reason


def check_approximator_multivariate_normal_score(approximator):
from bayesflow.approximators import PointApproximator
from bayesflow.scores import MultivariateNormalScore

if isinstance(approximator, PointApproximator):
for score in approximator.inference_network.scores.values():
if isinstance(score, MultivariateNormalScore):
pytest.skip(reason="MultivariateNormalScore is unstable")