Skip to content

Commit a1c5d43

Browse files
gpapamak authored and DistraxDev committed
Implement LowerUpperTriangularAffine as a composition of two TriangularAffines.
PiperOrigin-RevId: 421792044
1 parent da9ed56 commit a1c5d43

File tree

2 files changed

+38
-71
lines changed

2 files changed

+38
-71
lines changed

distrax/_src/bijectors/lower_upper_triangular_affine.py

Lines changed: 18 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,16 @@
1414
# ==============================================================================
1515
"""LU-decomposed affine bijector."""
1616

17-
from typing import Tuple
18-
1917
from distrax._src.bijectors import bijector as base
20-
from distrax._src.bijectors import unconstrained_affine
21-
import jax
18+
from distrax._src.bijectors import chain
19+
from distrax._src.bijectors import triangular_affine
2220
import jax.numpy as jnp
2321

2422

2523
Array = base.Array
2624

2725

28-
class LowerUpperTriangularAffine(base.Bijector):
26+
class LowerUpperTriangularAffine(chain.Chain):
2927
"""An affine bijector whose weight matrix is parameterized as A = LU.
3028
3129
This bijector is defined as `f(x) = Ax + b` where:
@@ -67,79 +65,33 @@ class docstring. Can also be a batch of matrices. If `matrix` is the
6765
generally not equal to the product `LU`.
6866
bias: the vector `b` in `LUx + b`. Can also be a batch of vectors.
6967
"""
70-
super().__init__(event_ndims_in=1, is_constant_jacobian=True)
71-
self._batch_shape = unconstrained_affine.common_batch_shape(matrix, bias)
72-
self._bias = bias
73-
74-
def compute_lu(matrix):
75-
# Lower-triangular matrix with ones on the diagonal.
76-
lower = jnp.eye(matrix.shape[-1]) + jnp.tril(matrix, -1)
77-
# Upper-triangular matrix.
78-
upper = jnp.triu(matrix)
79-
# Log absolute determinant.
80-
logdet = jnp.sum(jnp.log(jnp.abs(jnp.diag(matrix))))
81-
return lower, upper, logdet
82-
83-
compute_lu = jnp.vectorize(compute_lu, signature="(m,m)->(m,m),(m,m),()")
84-
self._lower, self._upper, self._logdet = compute_lu(matrix)
68+
if matrix.ndim < 2:
69+
raise ValueError(f"`matrix` must have at least 2 dimensions, got"
70+
f" {matrix.ndim}.")
71+
dim = matrix.shape[-1]
72+
# z = Ux
73+
self._upper_linear = triangular_affine.TriangularAffine(
74+
matrix, bias=jnp.zeros((dim,)), is_lower=False)
75+
# y = Lz + b
76+
lower = jnp.eye(dim) + jnp.tril(matrix, -1) # Replace diagonal with ones.
77+
self._lower_affine = triangular_affine.TriangularAffine(
78+
lower, bias, is_lower=True)
79+
super().__init__([self._lower_affine, self._upper_linear])
8580

8681
@property
8782
def lower(self) -> Array:
8883
"""The lower triangular matrix `L` with ones in the diagonal."""
89-
return self._lower
84+
return self._lower_affine.matrix
9085

9186
@property
9287
def upper(self) -> Array:
9388
"""The upper triangular matrix `U`."""
94-
return self._upper
89+
return self._upper_linear.matrix
9590

9691
@property
9792
def bias(self) -> Array:
9893
"""The shift `b` of the transformation."""
99-
return self._bias
100-
101-
def forward(self, x: Array) -> Array:
102-
"""Computes y = f(x)."""
103-
self._check_forward_input_shape(x)
104-
105-
def unbatched(single_x, lower, upper, bias):
106-
return lower @ (upper @ single_x) + bias
107-
108-
batched = jnp.vectorize(unbatched, signature="(m),(m,m),(m,m),(m)->(m)")
109-
return batched(x, self._lower, self._upper, self._bias)
110-
111-
def forward_log_det_jacobian(self, x: Array) -> Array:
112-
"""Computes log|det J(f)(x)|."""
113-
self._check_forward_input_shape(x)
114-
batch_shape = jax.lax.broadcast_shapes(self._batch_shape, x.shape[:-1])
115-
return jnp.broadcast_to(self._logdet, batch_shape)
116-
117-
def forward_and_log_det(self, x: Array) -> Tuple[Array, Array]:
118-
"""Computes y = f(x) and log|det J(f)(x)|."""
119-
return self.forward(x), self.forward_log_det_jacobian(x)
120-
121-
def inverse(self, y: Array) -> Array:
122-
"""Computes x = f^{-1}(y)."""
123-
self._check_inverse_input_shape(y)
124-
125-
def unbatched(single_y, lower, upper, bias):
126-
x = single_y - bias
127-
x = jax.scipy.linalg.solve_triangular(
128-
lower, x, lower=True, unit_diagonal=True)
129-
x = jax.scipy.linalg.solve_triangular(
130-
upper, x, lower=False, unit_diagonal=False)
131-
return x
132-
133-
batched = jnp.vectorize(unbatched, signature="(m),(m,m),(m,m),(m)->(m)")
134-
return batched(y, self._lower, self._upper, self._bias)
135-
136-
def inverse_log_det_jacobian(self, y: Array) -> Array:
137-
"""Computes log|det J(f^{-1})(y)|."""
138-
return -self.forward_log_det_jacobian(y)
139-
140-
def inverse_and_log_det(self, y: Array) -> Tuple[Array, Array]:
141-
"""Computes x = f^{-1}(y) and log|det J(f^{-1})(y)|."""
142-
return self.inverse(y), self.inverse_log_det_jacobian(y)
94+
return self._lower_affine.bias
14395

14496
def same_as(self, other: base.Bijector) -> bool:
14597
"""Returns True if this bijector is guaranteed to be the same as `other`."""

distrax/_src/bijectors/lower_upper_triangular_affine_test.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,26 @@ def test_jacobian_is_constant_property(self):
3535

3636
def test_properties(self):
3737
bijector = LowerUpperTriangularAffine(
38-
matrix=jnp.eye(4),
39-
bias=jnp.ones((4,)))
40-
np.testing.assert_allclose(bijector.lower, np.eye(4), atol=1e-6)
41-
np.testing.assert_allclose(bijector.upper, np.eye(4), atol=1e-6)
42-
np.testing.assert_allclose(bijector.bias, np.ones((4,)), atol=1e-6)
38+
matrix=jnp.array([[2., 3.], [4., 5.]]),
39+
bias=jnp.ones((2,)))
40+
np.testing.assert_allclose(
41+
bijector.lower, np.array([[1., 0.], [4., 1.]]), atol=1e-6)
42+
np.testing.assert_allclose(
43+
bijector.upper, np.array([[2., 3.], [0., 5.]]), atol=1e-6)
44+
np.testing.assert_allclose(bijector.bias, np.ones((2,)), atol=1e-6)
45+
46+
@parameterized.named_parameters(
47+
('matrix is 0d', {'matrix': np.zeros(()), 'bias': np.zeros((4,))}),
48+
('matrix is 1d', {'matrix': np.zeros((4,)), 'bias': np.zeros((4,))}),
49+
('bias is 0d', {'matrix': np.zeros((4, 4)), 'bias': np.zeros(())}),
50+
('matrix is not square',
51+
{'matrix': np.zeros((3, 4)), 'bias': np.zeros((4,))}),
52+
('matrix and bias shapes do not agree',
53+
{'matrix': np.zeros((4, 4)), 'bias': np.zeros((3,))}),
54+
)
55+
def test_raises_with_invalid_parameters(self, bij_params):
56+
with self.assertRaises(ValueError):
57+
LowerUpperTriangularAffine(**bij_params)
4358

4459
@chex.all_variants
4560
@parameterized.parameters(

0 commit comments

Comments (0)