9 changes: 8 additions & 1 deletion bayesflow/approximators/point_approximator.py
@@ -5,7 +5,7 @@
)

from bayesflow.types import Tensor
from bayesflow.utils import filter_kwargs, split_arrays, squeeze_inner_estimates_dict
from bayesflow.utils import filter_kwargs, split_arrays, squeeze_inner_estimates_dict, logging
from .continuous_approximator import ContinuousApproximator


@@ -119,6 +119,7 @@ def sample(
def _prepare_conditions(self, conditions: dict[str, np.ndarray], **kwargs) -> dict[str, Tensor]:
"""Adapts and converts the conditions to tensors."""
conditions = self.adapter(conditions, strict=False, stage="inference", **kwargs)
conditions.pop("inference_variables", None)
Contributor:
We could add this function to the ContinuousApproximator if it is identical between it and the PointApproximator.

Collaborator Author:
Yes! This and similar refactoring of the ContinuousApproximator is a good idea (but I would keep it out of this PR).
There is also the option of moving the tensor conversion into the adapter, possibly with an optional bool flag convert_to_tensor that defaults to False.
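For instance, a minimal sketch of that flag (hypothetical API, not part of this PR; apply_transforms stands in for the adapter's existing transform logic):

import keras

class Adapter:
    def __call__(self, data: dict, inverse: bool = False, convert_to_tensor: bool = False, **kwargs) -> dict:
        # run the registered forward/inverse transforms (assumed existing method)
        transformed = self.apply_transforms(data, inverse=inverse, **kwargs)
        if convert_to_tensor:
            # convert every array in the (possibly nested) output to a backend tensor
            transformed = keras.tree.map_structure(keras.ops.convert_to_tensor, transformed)
        return transformed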

return keras.tree.map_structure(keras.ops.convert_to_tensor, conditions)

def _apply_inverse_adapter_to_estimates(
@@ -130,6 +131,12 @@ def _apply_inverse_adapter_to_estimates(
for score_key, score_val in estimates.items():
processed[score_key] = {}
for head_key, estimate in score_val.items():
if head_key in self.inference_network.scores[score_key].not_transforming_like_vector:
logging.warning(
f"Estimate '{score_key}.{head_key}' is marked to not transform like a vector. "
f"It was treated like a vector by the adapter. Handle '{head_key}' estimates with care."
)

adapted = self.adapter(
{"inference_variables": estimate},
inverse=True,
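To illustrate why such estimates need care (a self-contained numpy sketch, not part of the diff): under a per-variable standardization adapter, a mean estimate inverts elementwise, but a covariance estimate must be rescaled by the outer product of the standard deviations.

import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=2.0, size=(10_000, 2))
mu, sigma = x.mean(axis=0), x.std(axis=0)

z = (x - mu) / sigma                  # adapter's forward standardization
cov_z = np.cov(z, rowvar=False)       # covariance estimated in standardized space

cov_correct = cov_z * np.outer(sigma, sigma)  # proper inverse transform for a covariance
cov_as_vector = cov_z * sigma + mu            # inverse transform applied "like a vector"

print(np.allclose(cov_correct, np.cov(x, rowvar=False), atol=0.1))    # True
print(np.allclose(cov_as_vector, np.cov(x, rowvar=False), atol=0.1))  # False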
2 changes: 1 addition & 1 deletion bayesflow/links/__init__.py
@@ -2,7 +2,7 @@

from .ordered import Ordered
from .ordered_quantiles import OrderedQuantiles
from .positive_semi_definite import PositiveSemiDefinite
from .positive_definite import PositiveDefinite

from ..utils._docs import _add_imports_to_all

47 changes: 47 additions & 0 deletions bayesflow/links/positive_definite.py
@@ -0,0 +1,47 @@
import keras

from keras.saving import register_keras_serializable as serializable

from bayesflow.types import Tensor
from bayesflow.utils import keras_kwargs, fill_triangular_matrix


@serializable(package="bayesflow.links")
class PositiveDefinite(keras.Layer):
"""Activation function to link from flat elements of a lower triangular matrix to a positive definite matrix."""

def __init__(self, **kwargs):
super().__init__(**keras_kwargs(kwargs))
self.built = True

def call(self, inputs: Tensor) -> Tensor:
# Build the Cholesky factor from the flat inputs
L = fill_triangular_matrix(inputs, positive_diag=True)

# calculate the positive definite matrix from its Cholesky factor
psd = keras.ops.matmul(
L,
keras.ops.moveaxis(L, -2, -1), # L transposed
)
return psd

def compute_output_shape(self, input_shape):
m = input_shape[-1]
n = int((0.25 + 2.0 * m) ** 0.5 - 0.5)
return input_shape[:-1] + (n, n)

def compute_input_shape(self, output_shape):
"""
Returns the shape of the flat parameterization of a triangular Cholesky factor.
A lower triangular n x n matrix has m = n * (n + 1) / 2 nonzero elements.

Example
-------
>>> PositiveDefinite().compute_input_shape((None, 3, 3))
(None, 6)
"""
n = output_shape[-1]
m = int(n * (n + 1) / 2)
return output_shape[:-2] + (m,)
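A quick usage sketch of this link (illustrative, not part of the diff):

import keras
from bayesflow.links import PositiveDefinite

link = PositiveDefinite()
flat = keras.random.normal((8, 6)) * 0.1     # m = 6 -> 3x3 matrices; small scale for stability
matrix = link(flat)                          # shape (8, 3, 3), positive definite by construction
print(link.compute_output_shape((8, 6)))     # (8, 3, 3)
print(link.compute_input_shape((8, 3, 3)))   # (8, 6)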
20 changes: 0 additions & 20 deletions bayesflow/links/positive_semi_definite.py

This file was deleted.

2 changes: 1 addition & 1 deletion bayesflow/networks/point_inference_network.py
@@ -132,7 +132,7 @@ def call(
if xz is None and not self.built:
raise ValueError("Cannot build inference network without inference variables.")
if conditions is None: # unconditional estimation uses a fixed input vector
conditions = keras.ops.convert_to_tensor([[1.0]], dtype=keras.ops.dtype(xz))
conditions = keras.ops.convert_to_tensor([[1.0]])

# pass conditions to the shared subnet
output = self.subnet(conditions, training=training)
19 changes: 11 additions & 8 deletions bayesflow/scores/multivariate_normal_score.py
@@ -4,8 +4,7 @@
from keras.saving import register_keras_serializable as serializable

from bayesflow.types import Shape, Tensor
from bayesflow.links import PositiveSemiDefinite
from bayesflow.utils import logging
from bayesflow.links import PositiveDefinite

from .parametric_distribution_score import ParametricDistributionScore

@@ -21,10 +20,12 @@ def __init__(self, dim: int = None, links: dict = None, **kwargs):
super().__init__(links=links, **kwargs)

self.dim = dim
self.links = links or {"covariance": PositiveSemiDefinite()}
self.config = {"dim": dim}
self.links = links or {"covariance": PositiveDefinite()}

# mark head for covariance matrix as an exception for adapter transformations
self.not_transforming_like_vector = ["covariance"]

logging.warning("MultivariateNormalScore is unstable.")
self.config = {"dim": dim}

def get_config(self):
base_config = super().get_config()
@@ -60,12 +61,12 @@ def log_prob(self, x: Tensor, mean: Tensor, covariance: Tensor) -> Tensor:
A tensor containing the log probability densities for each sample in `x` under the
given Gaussian distribution.
"""
diff = x[:, None, :] - mean
inv_covariance = keras.ops.inv(covariance)
diff = x - mean
precision = keras.ops.inv(covariance)
log_det_covariance = keras.ops.slogdet(covariance)[1] # Only take the log of the determinant part

# Compute the quadratic term in the exponential of the multivariate Gaussian
quadratic_term = keras.ops.einsum("...i,...ij,...j->...", diff, inv_covariance, diff)
quadratic_term = keras.ops.einsum("...i,...ij,...j->...", diff, precision, diff)

# Compute the log probability density
log_prob = -0.5 * (self.dim * keras.ops.log(2 * math.pi) + log_det_covariance + quadratic_term)
@@ -97,6 +98,8 @@ def sample(self, batch_shape: Shape, mean: Tensor, covariance: Tensor) -> Tensor
Tensor
A tensor of shape (batch_size, num_samples, D) containing the generated samples.
"""
if len(batch_shape) == 1:
batch_shape = (1,) + batch_shape
batch_size, num_samples = batch_shape
dim = keras.ops.shape(mean)[-1]
if keras.ops.shape(mean) != (batch_size, dim):
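As a sanity check for the log_prob implementation (an illustrative numpy/scipy sketch, not part of the diff), the density log N(x; mean, covariance) = -0.5 * (D * log(2 * pi) + log|covariance| + diff^T precision diff) can be compared against scipy:

import numpy as np
from scipy.stats import multivariate_normal

rng = np.random.default_rng(1)
dim = 3
mean = rng.normal(size=dim)
A = rng.normal(size=(dim, dim))
covariance = A @ A.T + dim * np.eye(dim)  # guarantee positive definiteness
x = rng.normal(size=dim)

diff = x - mean
precision = np.linalg.inv(covariance)
log_det_covariance = np.linalg.slogdet(covariance)[1]  # only the log-determinant part
quadratic_term = diff @ precision @ diff
log_prob = -0.5 * (dim * np.log(2 * np.pi) + log_det_covariance + quadratic_term)

print(np.isclose(log_prob, multivariate_normal(mean, covariance).logpdf(x)))  # True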
3 changes: 1 addition & 2 deletions bayesflow/scores/parametric_distribution_score.py
@@ -51,5 +51,4 @@ def score(self, estimates: dict[str, Tensor], targets: Tensor, weights: Tensor =
"""
scores = -self.log_prob(x=targets, **estimates)
score = self.aggregate(scores, weights)
# multipy to mitigate instability due to relatively high values of parametric score
return score * 0.01
return score
21 changes: 16 additions & 5 deletions bayesflow/scores/scoring_rule.py
@@ -29,6 +29,8 @@ def __init__(
self.subnets_kwargs = subnets_kwargs or {}
self.links = links or {}

self.not_transforming_like_vector = []
Collaborator:

Please add a comment documenting what this should contain, so people who want to set this in a subclass know what to use

Contributor:

+1, I would also like to know if we can handle this in a better way than a special attribute.

Collaborator:

Do we want this to be variable within a class, i.e., does this have to be an instance variable? If not, a constant class variable might be an option, like

class ScoringRule:
    #: This variable lists ... (the #: syntax should allow sphinx to parse this)
    NOT_TRANSFORMING_LIKE_VECTOR = tuple()  # use immutable type tuple instead of list

Collaborator Author:

Good idea to use a class variable for this! Implemented it in d87b0b9.

I added documenting comments as well.

Collaborator Author:

Regarding a better way of handling this in general:

There is a discussion in #304 about the long-term plans for special estimators and how they interact with adapter transformations.
In my opinion, what we have here suffices for now and is a reasonable first step.
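For concreteness, a subclass would then override the constant, e.g. (a sketch of the suggestion above, not the merged code):

class MultivariateNormalScore(ParametricDistributionScore):
    #: Names of estimation heads whose outputs do not transform like vectors
    #: under adapter transformations and therefore need special handling.
    NOT_TRANSFORMING_LIKE_VECTOR = ("covariance",)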


self.config = {"subnets_kwargs": self.subnets_kwargs}

def get_config(self):
@@ -95,14 +97,14 @@ def get_link(self, key: str) -> keras.Layer:
else:
return self.links[key]

def get_head(self, key: str, shape: Shape) -> keras.Sequential:
def get_head(self, key: str, output_shape: Shape) -> keras.Sequential:
"""For a specified head key and shape, request corresponding head network.

Parameters
----------
key : str
Name of the head for which to request a network.
shape: Shape
output_shape: Shape
The required output shape of the point estimates.

Returns
@@ -111,10 +113,19 @@ def get_head(self, key: str, shape: Shape) -> keras.Sequential:
Head network consisting of a learnable projection, a reshape and a link operation
to parameterize estimates.
"""
subnet = self.get_subnet(key)
dense = keras.layers.Dense(units=math.prod(shape))
reshape = keras.layers.Reshape(target_shape=shape)
# initialize head components back to front
link = self.get_link(key)

# link input shape can differ from output shape
if hasattr(link, "compute_input_shape"):
link_input_shape = link.compute_input_shape(output_shape)
else:
link_input_shape = output_shape

reshape = keras.layers.Reshape(target_shape=link_input_shape)
dense = keras.layers.Dense(units=math.prod(link_input_shape))
subnet = self.get_subnet(key)

return keras.Sequential([subnet, dense, reshape, link])

def score(self, estimates: dict[str, Tensor], targets: Tensor, weights: Tensor) -> Tensor:
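For illustration (assumed shapes, not part of the diff): a covariance head for D = 3 targets the output shape (3, 3), but the link consumes m = D * (D + 1) / 2 = 6 flat elements, so the dense projection and reshape are sized from the link's input shape (the learnable subnet is omitted for brevity):

import math
import keras
from bayesflow.links import PositiveDefinite

link = PositiveDefinite()
output_shape = (3, 3)                                      # D x D covariance estimate
link_input_shape = link.compute_input_shape(output_shape)  # (6,)

head = keras.Sequential([
    keras.layers.Dense(units=math.prod(link_input_shape)),  # project to 6 flat elements
    keras.layers.Reshape(target_shape=link_input_shape),
    link,                                                    # (..., 6) -> (..., 3, 3)
])
print(head(keras.ops.ones((8, 5))).shape)  # (8, 3, 3)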
1 change: 1 addition & 0 deletions bayesflow/utils/__init__.py
@@ -66,6 +66,7 @@
tile_axis,
tree_concatenate,
tree_stack,
fill_triangular_matrix,
)
from .validators import check_lengths_same
from .workflow_utils import find_inference_network, find_summary_network
75 changes: 75 additions & 0 deletions bayesflow/utils/tensor_utils.py
@@ -277,3 +277,78 @@ def stack(*items):
return keras.ops.stack(items, axis=axis)

return keras.tree.map_structure(stack, *structures)


def fill_triangular_matrix(x: Tensor, upper: bool = False, positive_diag: bool = False):
"""
Reshapes a batch of matrix elements into a triangular matrix (either upper or lower).

Note: If final axis has length 1, this simply reshapes to (batch_size, 1, 1) and optionally applies softplus.

Parameters
----------
x : Tensor of shape (batch_size, m)
Batch of flattened nonzero matrix elements for triangular matrix.
upper : bool
Return upper triangular matrix if True, else lower triangular matrix. Default is False.
positive_diag : bool
Whether to apply a softplus operation to diagonal elements. Default is False.

Returns
-------
Tensor of shape (batch_size, n, n)
Batch of triangular matrices with m = n * (n + 1) / 2 unique nonzero elements.

Raises
------
ValueError
If the number of provided elements m does not correspond to a possible triangular matrix shape
(n, n), i.e., if n = sqrt(1/4 + 2 * m) - 1/2 is not an integer (from m = n * (n + 1) / 2).
"""
batch_shape = x.shape[:-1]
m = x.shape[-1]

if m == 1:
y = keras.ops.reshape(x, (-1, 1, 1))
if positive_diag:
y = keras.activations.softplus(y)
return y

# Calculate matrix shape
n = (0.25 + 2 * m) ** 0.5 - 0.5
if not np.isclose(np.floor(n), n):
raise ValueError(f"Input right-most shape ({m}) does not correspond to a triangular matrix.")
else:
n = int(n)

# Trick: Create triangular matrix by concatenating with a flipped version of its tail, then reshape.
x_tail = keras.ops.take(x, indices=list(range((m - (n**2 - m)), x.shape[-1])), axis=-1)
if not upper:
y = keras.ops.concatenate([x_tail, keras.ops.flip(x, axis=-1)], axis=len(batch_shape))
y = keras.ops.reshape(y, (-1, n, n))
y = keras.ops.tril(y)

if positive_diag:
y_offdiag = keras.ops.tril(y, k=-1)
# carve out the diagonal by setting upper and lower off-diagonals to zero
y_diag = keras.ops.tril(
keras.ops.triu(keras.activations.softplus(y)), # apply softplus to enforce positivity
)
y = y_diag + y_offdiag

else:
y = keras.ops.concatenate([x, keras.ops.flip(x_tail, axis=-1)], axis=len(batch_shape))
y = keras.ops.reshape(y, (-1, n, n))
y = keras.ops.triu(y)

if positive_diag:
y_offdiag = keras.ops.triu(y, k=1)
# carve out the diagonal by setting upper and lower off-diagonals to zero
y_diag = keras.ops.tril(
keras.ops.triu(keras.activations.softplus(y)), # apply softplus to enforce positivity
)
y = y_diag + y_offdiag

return y
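A small usage sketch (illustrative; the element layout follows from the concatenate-and-flip trick above):

import keras
from bayesflow.utils import fill_triangular_matrix

x = keras.ops.convert_to_tensor([[1.0, 2.0, 3.0]])  # m = 3 -> n = 2
L = fill_triangular_matrix(x)
print(keras.ops.shape(L))  # (1, 2, 2)
# L == [[[3., 0.],
#        [2., 1.]]]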
16 changes: 8 additions & 8 deletions tests/test_links/conftest.py
@@ -15,7 +15,7 @@ def num_variables():

@pytest.fixture()
def generic_preactivation(batch_size):
return keras.ops.ones((batch_size, 4, 4))
return keras.ops.ones((batch_size, 6))


@pytest.fixture()
@@ -33,18 +33,18 @@ def ordered_quantiles():


@pytest.fixture()
def positive_semi_definite():
from bayesflow.links import PositiveSemiDefinite
def positive_definite():
from bayesflow.links import PositiveDefinite

return PositiveSemiDefinite()
return PositiveDefinite()


@pytest.fixture()
def linear():
return keras.layers.Activation("linear")


@pytest.fixture(params=["ordered", "ordered_quantiles", "positive_semi_definite", "linear"], scope="function")
@pytest.fixture(params=["ordered", "ordered_quantiles", "positive_definite", "linear"], scope="function")
def link(request):
return request.getfixturevalue(request.param)

@@ -84,6 +84,6 @@ def unordered(batch_size, num_quantiles, num_variables):
return keras.random.normal((batch_size, num_quantiles, num_variables))


@pytest.fixture()
def random_matrix_batch(batch_size, num_variables):
return keras.random.normal((batch_size, num_variables, num_variables))
# @pytest.fixture()
# def random_matrix_batch(batch_size, num_variables):
# return keras.random.normal((batch_size, num_variables, num_variables))
22 changes: 10 additions & 12 deletions tests/test_links/test_links.py
@@ -3,13 +3,6 @@
import pytest


def test_link_output(link, generic_preactivation):
output_shape = link.compute_output_shape(generic_preactivation.shape)
output = link(generic_preactivation)

assert output_shape == output.shape


def test_invalid_shape_for_ordered_quantiles(ordered_quantiles, batch_size, num_quantiles, num_variables):
with pytest.raises(AssertionError) as excinfo:
ordered_quantiles.build((batch_size, batch_size, num_quantiles, num_variables))
@@ -59,16 +52,21 @@ def test_quantile_ordering(quantiles, unordered):
check_ordering(output, axis)


def test_positive_semi_definite(random_matrix_batch):
from bayesflow.links import PositiveSemiDefinite
def test_positive_definite(positive_definite, batch_size, num_variables):
input_shape = positive_definite.compute_input_shape((batch_size, num_variables, num_variables))

activation = PositiveSemiDefinite()
# Strongly negative pre-activations lead to numerical instabilities -> reduce scale
random_preactivation = keras.random.normal(input_shape) * 0.1
output = positive_definite(random_preactivation)

output = activation(random_matrix_batch)
# Check if output is invertible
np.linalg.inv(output)

# Calculate eigenvalues to test for positive definiteness
output = keras.ops.convert_to_numpy(output)
eigenvalues = np.linalg.eig(output).eigenvalues

assert np.all(eigenvalues.real > 0) and np.all(np.isclose(eigenvalues.imag, 0)), (
f"output is not positive semi-definite: real={eigenvalues.real}, imag={eigenvalues.imag}"
f"output is not positive definite: min(real)={np.min(eigenvalues.real)}, "
f"max(abs(imag))={np.max(np.abs(eigenvalues.imag))}"
)