Skip to content

Commit a8e9a72

Browse files
committed
Restore PositiveDefinite link with deprecation warning
1 parent b43f1cc commit a8e9a72

File tree

2 files changed

+54
-0
lines changed

2 files changed

+54
-0
lines changed

bayesflow/links/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from .ordered import Ordered
44
from .ordered_quantiles import OrderedQuantiles
55
from .cholesky_factor import CholeskyFactor
6+
from .positive_definite import PositiveDefinite
67

78
from ..utils._docs import _add_imports_to_all
89

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import keras
2+
3+
from bayesflow.types import Tensor
4+
from bayesflow.utils import layer_kwargs, fill_triangular_matrix
5+
from bayesflow.utils.serialization import serializable
6+
7+
from warnings import warn
8+
9+
10+
@serializable("bayesflow.links")
11+
class PositiveDefinite(keras.Layer):
12+
"""Activation function to link from flat elements of a lower triangular matrix to a positive definite matrix."""
13+
14+
def __init__(self, **kwargs):
15+
super().__init__(**layer_kwargs(kwargs))
16+
self.built = True
17+
18+
warn(
19+
"This class is deprecated. It was replaced by bayesflow.links.CholeskyFactor.",
20+
DeprecationWarning,
21+
stacklevel=2,
22+
)
23+
24+
def call(self, inputs: Tensor) -> Tensor:
25+
# Build cholesky factor from inputs
26+
L = fill_triangular_matrix(inputs, positive_diag=True)
27+
28+
# calculate positive definite matrix from cholesky factors
29+
psd = keras.ops.matmul(
30+
L,
31+
keras.ops.moveaxis(L, -2, -1), # L transposed
32+
)
33+
return psd
34+
35+
def compute_output_shape(self, input_shape):
36+
m = input_shape[-1]
37+
n = int((0.25 + 2.0 * m) ** 0.5 - 0.5)
38+
return input_shape[:-1] + (n, n)
39+
40+
def compute_input_shape(self, output_shape):
41+
"""
42+
Returns the shape of parameterization of a cholesky factor triangular matrix.
43+
44+
There are m nonzero elements of a lower triangular nxn matrix with m = n * (n + 1) / 2.
45+
46+
Example
47+
-------
48+
>>> PositiveDefinite().compute_output_shape((None, 3, 3))
49+
6
50+
"""
51+
n = output_shape[-1]
52+
m = int(n * (n + 1) / 2)
53+
return output_shape[:-2] + (m,)

0 commit comments

Comments
 (0)