diff --git a/bayesflow/networks/standardization/standardization.py b/bayesflow/networks/standardization/standardization.py
index 9dfdf2cb2..a3fe9cbdb 100644
--- a/bayesflow/networks/standardization/standardization.py
+++ b/bayesflow/networks/standardization/standardization.py
@@ -117,6 +117,9 @@ def call(
             case "left_side_scale":
                 # x_ij = sigma_i * x_ij'
                 out = val * keras.ops.moveaxis(std, -1, -2)
+            case "right_side_scale_inverse":
+                # x_ij = x_ij' / sigma_j
+                out = val / std
             case _:
                 out = val
 
diff --git a/bayesflow/scores/multivariate_normal_score.py b/bayesflow/scores/multivariate_normal_score.py
index 71d45d4b2..26d7c7b1d 100644
--- a/bayesflow/scores/multivariate_normal_score.py
+++ b/bayesflow/scores/multivariate_normal_score.py
@@ -13,13 +13,14 @@ class MultivariateNormalScore(ParametricDistributionScore):
     r""":math:`S(\hat p_{\mu, \Sigma}, \theta; k) = -\log( \mathcal N (\theta; \mu, \Sigma))`
 
-    Scores a predicted mean and (Cholesky factor of the) covariance matrix with the log-score of the probability
-    of the materialized value.
+    Scores a predicted mean and lower-triangular Cholesky factor :math:`L` of the precision matrix :math:`P`
+    with the log-score of the probability of the materialized value. The precision matrix is
+    the inverse of the covariance matrix, :math:`L^T L = P = \Sigma^{-1}`.
     """
 
-    NOT_TRANSFORMING_LIKE_VECTOR_WARNING = ("cov_chol",)
+    NOT_TRANSFORMING_LIKE_VECTOR_WARNING = ("precision_cholesky_factor",)
     """
-    Marks head for covariance matrix Cholesky factor as an exception for adapter transformations.
+    Marks head for precision matrix Cholesky factor as an exception for adapter transformations.
 
     This variable contains names of prediction heads that should lead to a warning when the adapter
     is applied in inverse direction to them.
@@ -27,13 +28,11 @@ class MultivariateNormalScore(ParametricDistributionScore):
     For more information see :py:class:`ScoringRule`.
     """
 
-    TRANSFORMATION_TYPE: dict[str, str] = {"cov_chol": "left_side_scale"}
+    TRANSFORMATION_TYPE: dict[str, str] = {"precision_cholesky_factor": "right_side_scale_inverse"}
     """
-    Marks covariance Cholesky factor head to handle de-standardization as for covariant rank-(0,2) tensors.
+    Marks the precision Cholesky factor head for de-standardization with the "right_side_scale_inverse" transform.
 
-    The appropriate inverse of the standardization operation is
-
-    x_ij = sigma_i * x_ij'.
+    See :py:class:`bayesflow.networks.Standardization` for more information on supported de-standardization options.
 
     For the mean head the default ("location_scale") is not overridden.
     """
@@ -42,7 +41,7 @@ def __init__(self, dim: int = None, links: dict = None, **kwargs):
         super().__init__(links=links, **kwargs)
 
         self.dim = dim
-        self.links = links or {"cov_chol": CholeskyFactor()}
+        self.links = links or {"precision_cholesky_factor": CholeskyFactor()}
 
         self.config = {"dim": dim}
 
@@ -52,16 +51,16 @@ def get_config(self):
     def get_head_shapes_from_target_shape(self, target_shape: Shape) -> dict[str, Shape]:
         self.dim = target_shape[-1]
-        return dict(mean=(self.dim,), cov_chol=(self.dim, self.dim))
+        return dict(mean=(self.dim,), precision_cholesky_factor=(self.dim, self.dim))
 
-    def log_prob(self, x: Tensor, mean: Tensor, cov_chol: Tensor) -> Tensor:
+    def log_prob(self, x: Tensor, mean: Tensor, precision_cholesky_factor: Tensor) -> Tensor:
         """
         Compute the log probability density of a multivariate Gaussian distribution.
 
        This function calculates the log probability density for each sample in `x` under a
-        multivariate Gaussian distribution with the given `mean` and `cov_chol`.
+        multivariate Gaussian distribution with the given `mean` and `precision_cholesky_factor`.
 
-        The computation includes the determinant of the covariance matrix, its inverse, and the quadratic
+        The computation uses the Cholesky factor of the precision matrix for both the log-determinant and the quadratic
         form in the exponential term of the Gaussian density function.
 
         Parameters
         ----------
@@ -71,8 +70,9 @@ def log_prob(self, x: Tensor, mean: Tensor, cov_chol: Tensor) -> Tensor:
             The shape should be compatible with broadcasting against `mean`.
         mean : Tensor
             A tensor representing the mean of the multivariate Gaussian distribution.
-        covariance : Tensor
-            A tensor representing the covariance matrix of the multivariate Gaussian distribution.
+        precision_cholesky_factor : Tensor
+            A tensor representing the lower-triangular Cholesky factor of the precision matrix
+            of the multivariate Gaussian distribution.
 
         Returns
         -------
@@ -82,29 +82,27 @@ def log_prob(self, x: Tensor, mean: Tensor, cov_chol: Tensor) -> Tensor:
         """
         diff = x - mean
-        # Calculate precision from Cholesky factors of covariance matrix
-        cov_chol_inv = keras.ops.inv(cov_chol)
-        precision = keras.ops.matmul(
-            keras.ops.swapaxes(cov_chol_inv, -2, -1),
-            cov_chol_inv,
-        )
-
         # Compute log determinant, exploiting Cholesky factors
-        log_det_covariance = keras.ops.log(keras.ops.prod(keras.ops.diagonal(cov_chol, axis1=1, axis2=2), axis=1)) * 2
+        log_det_covariance = -2 * keras.ops.sum(
+            keras.ops.log(keras.ops.diagonal(precision_cholesky_factor, axis1=1, axis2=2)), axis=1
+        )
 
-        # Compute the quadratic term in the exponential of the multivariate Gaussian
-        quadratic_term = keras.ops.einsum("...i,...ij,...j->...", diff, precision, diff)
+        # Compute the quadratic term in the exponential of the multivariate Gaussian from Cholesky factors
+        # diff^T * precision_cholesky_factor^T * precision_cholesky_factor * diff
+        quadratic_term = keras.ops.einsum(
+            "...i,...ji,...jk,...k->...", diff, precision_cholesky_factor, precision_cholesky_factor, diff
+        )
 
         # Compute the log probability density
         log_prob = -0.5 * (self.dim * keras.ops.log(2 * math.pi) + log_det_covariance + quadratic_term)
 
         return log_prob
 
-    def sample(self, batch_shape: Shape, mean: Tensor, cov_chol: Tensor) -> Tensor:
+    def sample(self, batch_shape: Shape, mean: Tensor, precision_cholesky_factor: Tensor) -> Tensor:
         """
         Generate samples from a multivariate Gaussian distribution.
 
-        Independent standard normal samples are transformed using the Cholesky factor of the covariance matrix
+        Independent standard normal samples are transformed using the inverse Cholesky factor of the precision matrix
         to generate correlated samples.
 
         Parameters
         ----------
@@ -114,8 +112,9 @@ def sample(self, batch_shape: Shape, mean: Tensor, cov_chol: Tensor) -> Tensor:
         mean : Tensor
             A tensor representing the mean of the multivariate Gaussian distribution. Must have shape
             (batch_size, D), where D is the dimensionality of the distribution.
-        cov_chol : Tensor
-            A tensor representing a Cholesky factor of the covariance matrix of the multivariate Gaussian distribution.
+        precision_cholesky_factor : Tensor
+            A tensor representing the lower-triangular Cholesky factor of the precision matrix
+            of the multivariate Gaussian distribution.
             Must have shape (batch_size, D, D), where D is the dimensionality.
 
        Returns
         -------
@@ -123,6 +122,8 @@ def sample(self, batch_shape: Shape, mean: Tensor, cov_chol: Tensor) -> Tensor:
         Tensor
             A tensor of shape (batch_size, num_samples, D) containing the generated samples.
         """
+        # Invert the precision factor L to obtain a covariance factor C = L^{-1}, which satisfies C C^T = Sigma
+        covariance_cholesky_factor = keras.ops.inv(precision_cholesky_factor)
         if len(batch_shape) == 1:
             batch_shape = (1,) + tuple(batch_shape)
         batch_size, num_samples = batch_shape
@@ -130,16 +131,16 @@ def sample(self, batch_shape: Shape, mean: Tensor, cov_chol: Tensor) -> Tensor:
 
         if keras.ops.shape(mean) != (batch_size, dim):
             raise ValueError(f"mean must have shape (batch_size, {dim}), but got {keras.ops.shape(mean)}")
 
-        if keras.ops.shape(cov_chol) != (batch_size, dim, dim):
+        if keras.ops.shape(precision_cholesky_factor) != (batch_size, dim, dim):
             raise ValueError(
-                f"covariance Cholesky factor must have shape (batch_size, {dim}, {dim}),"
-                f"but got {keras.ops.shape(cov_chol)}"
+                f"precision Cholesky factor must have shape (batch_size, {dim}, {dim}), "
+                f"but got {keras.ops.shape(precision_cholesky_factor)}"
             )
 
         # Use Cholesky decomposition to generate samples
         normal_samples = keras.random.normal((*batch_shape, dim))
-        scaled_normal = keras.ops.einsum("ijk,ilk->ilj", cov_chol, normal_samples)
+        scaled_normal = keras.ops.einsum("ijk,ilk->ilj", covariance_cholesky_factor, normal_samples)
 
         samples = mean[:, None, :] + scaled_normal
         return samples
diff --git a/bayesflow/utils/tensor_utils.py b/bayesflow/utils/tensor_utils.py
index 7fa963fa6..4fa113682 100644
--- a/bayesflow/utils/tensor_utils.py
+++ b/bayesflow/utils/tensor_utils.py
@@ -311,7 +311,7 @@ def stack(*items):
 
     return keras.tree.map_structure(stack, *structures)
 
 
-def fill_triangular_matrix(x: Tensor, upper: bool = False, positive_diag: bool = False):
+def fill_triangular_matrix(x: Tensor, upper: bool = False):
     """
     Reshapes a batch of matrix elements into a triangular matrix (either upper or lower).
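
Reviewer note, not part of the patch: the log-density above hinges on two identities under the
convention P = L^T L = Sigma^{-1} with L lower-triangular: log|Sigma| = -2 * sum(log(diag(L))),
and diff^T P diff = ||L diff||^2, so no matrix inverse is needed. A minimal NumPy/SciPy sketch
checking both against scipy's reference logpdf (all names here are illustrative, not from the patch):

    import numpy as np
    from scipy.stats import multivariate_normal

    rng = np.random.default_rng(0)
    dim = 3

    # Lower-triangular L with positive diagonal, as the CholeskyFactor link would produce
    L = np.tril(rng.normal(size=(dim, dim)))
    np.fill_diagonal(L, np.abs(np.diag(L)) + 0.5)

    precision = L.T @ L                    # P = L^T L
    covariance = np.linalg.inv(precision)  # Sigma = P^{-1}

    mean = rng.normal(size=dim)
    x = rng.normal(size=dim)
    diff = x - mean

    # log|Sigma| = -log|P| = -2 * sum(log(diag(L)))
    log_det_covariance = -2 * np.sum(np.log(np.diag(L)))
    # diff^T P diff = ||L diff||^2
    quadratic_term = np.sum((L @ diff) ** 2)
    log_prob = -0.5 * (dim * np.log(2 * np.pi) + log_det_covariance + quadratic_term)

    assert np.isclose(log_prob, multivariate_normal(mean, covariance).logpdf(x))
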
diff --git a/tests/test_approximators/test_fit.py b/tests/test_approximators/test_fit.py
index b561efb77..27d4716c4 100644
--- a/tests/test_approximators/test_fit.py
+++ b/tests/test_approximators/test_fit.py
@@ -3,7 +3,6 @@
 import pytest
 import io
 from contextlib import redirect_stdout
-from tests.utils import check_approximator_multivariate_normal_score
 
 
 @pytest.mark.skip(reason="not implemented")
@@ -20,9 +19,6 @@ def test_fit(amortizer, dataset):
 
 
 def test_loss_progress(approximator, train_dataset, validation_dataset):
-    # as long as MultivariateNormalScore is unstable, skip fit progress test
-    check_approximator_multivariate_normal_score(approximator)
-
     approximator.compile(optimizer="AdamW")
     num_epochs = 3
 
diff --git a/tests/test_approximators/test_log_prob.py b/tests/test_approximators/test_log_prob.py
index 8cfbb2fe6..9c96cdeb6 100644
--- a/tests/test_approximators/test_log_prob.py
+++ b/tests/test_approximators/test_log_prob.py
@@ -1,12 +1,10 @@
 import keras
 import numpy as np
-from tests.utils import check_combination_simulator_adapter, check_approximator_multivariate_normal_score
+from tests.utils import check_combination_simulator_adapter
 
 
 def test_approximator_log_prob(approximator, simulator, batch_size, adapter):
     check_combination_simulator_adapter(simulator, adapter)
 
-    # as long as MultivariateNormalScore is unstable, skip
-    check_approximator_multivariate_normal_score(approximator)
     num_batches = 4
     data = simulator.sample((num_batches * batch_size,))
diff --git a/tests/test_approximators/test_sample.py b/tests/test_approximators/test_sample.py
index d7c2a3bcf..c62ffc581 100644
--- a/tests/test_approximators/test_sample.py
+++ b/tests/test_approximators/test_sample.py
@@ -1,11 +1,9 @@
 import keras
-from tests.utils import check_combination_simulator_adapter, check_approximator_multivariate_normal_score
+from tests.utils import check_combination_simulator_adapter
 
 
 def test_approximator_sample(approximator, simulator, batch_size, adapter):
     check_combination_simulator_adapter(simulator, adapter)
 
-    # as long as MultivariateNormalScore is unstable, skip
-    check_approximator_multivariate_normal_score(approximator)
     num_batches = 4
     data = simulator.sample((num_batches * batch_size,))
diff --git a/tests/test_networks/test_standardization.py b/tests/test_networks/test_standardization.py
index 86881a384..863d977ea 100644
--- a/tests/test_networks/test_standardization.py
+++ b/tests/test_networks/test_standardization.py
@@ -1,3 +1,4 @@
+import pytest
 import numpy as np
 import keras
 
@@ -156,9 +157,11 @@ def test_transformation_type_both_sides_scale():
     np.testing.assert_allclose(cov_input, cov_standardized_and_recovered, atol=1e-4)
 
 
-def test_transformation_type_left_side_scale():
+@pytest.mark.parametrize("transformation_type", ["left_side_scale", "right_side_scale_inverse"])
+def test_transformation_type_one_side_scale(transformation_type):
     # Fix a known covariance and mean in original (not standardized space)
     covariance = np.array([[1, 0.5], [0.5, 2.0]], dtype="float32")
+
     mean = np.array([1, 10], dtype="float32")
 
     # Generate samples
@@ -177,9 +180,15 @@ def test_transformation_type_left_side_scale():
     cov_standardized = np.cov(keras.ops.convert_to_numpy(standardized), rowvar=False)
     cov_standardized = keras.ops.convert_to_tensor(cov_standardized)
     chol_standardized = keras.ops.cholesky(cov_standardized)  # (dim, dim)
+
+    # We test the right_side_scale_inverse transformation by back-transforming a precision Cholesky factor
+    # instead of a covariance Cholesky factor.
+ if "inverse" in transformation_type: + chol_standardized = keras.ops.inv(chol_standardized) + # Inverse standardization of covariance matrix in standardized space chol_standardized_and_recovered = layer( - chol_standardized, stage="inference", forward=False, transformation_type="left_side_scale" + chol_standardized, stage="inference", forward=False, transformation_type=transformation_type ) random_input = keras.ops.convert_to_numpy(random_input) @@ -187,6 +196,9 @@ def test_transformation_type_left_side_scale(): cov_input = np.cov(random_input, rowvar=False) chol_input = np.linalg.cholesky(cov_input) + if "inverse" in transformation_type: + chol_input = np.linalg.inv(chol_input) + np.testing.assert_allclose(chol_input, chol_standardized_and_recovered, atol=1e-4) diff --git a/tests/utils/check_combinations.py b/tests/utils/check_combinations.py index 8565703c8..1be5b1b5a 100644 --- a/tests/utils/check_combinations.py +++ b/tests/utils/check_combinations.py @@ -19,13 +19,3 @@ def check_combination_simulator_adapter(simulator, adapter): # to be used as sample weight, no error is raised currently. # Don't use this fixture combination for further tests. pytest.skip(reason="Do not use this fixture combination for further tests") # TODO: better reason - - -def check_approximator_multivariate_normal_score(approximator): - from bayesflow.approximators import PointApproximator - from bayesflow.scores import MultivariateNormalScore - - if isinstance(approximator, PointApproximator): - for score in approximator.inference_network.scores.values(): - if isinstance(score, MultivariateNormalScore): - pytest.skip(reason="MultivariateNormalScore is unstable")