|
14 | 14 | # ============================================================================== |
15 | 15 | """LU-decomposed affine bijector.""" |
16 | 16 |
|
17 | | -from typing import Tuple |
18 | | - |
19 | 17 | from distrax._src.bijectors import bijector as base |
20 | | -from distrax._src.bijectors import unconstrained_affine |
21 | | -import jax |
| 18 | +from distrax._src.bijectors import chain |
| 19 | +from distrax._src.bijectors import triangular_affine |
22 | 20 | import jax.numpy as jnp |
23 | 21 |
|
24 | 22 |
|
# Re-export the array type alias from the base bijector module so this
# file's annotations stay consistent with the rest of distrax.
Array = base.Array
26 | 24 |
|
27 | 25 |
|
28 | | -class LowerUpperTriangularAffine(base.Bijector): |
| 26 | +class LowerUpperTriangularAffine(chain.Chain): |
29 | 27 | """An affine bijector whose weight matrix is parameterized as A = LU. |
30 | 28 |
|
31 | 29 | This bijector is defined as `f(x) = Ax + b` where: |
@@ -67,79 +65,33 @@ class docstring. Can also be a batch of matrices. If `matrix` is the |
67 | 65 | generally not equal to the product `LU`. |
68 | 66 | bias: the vector `b` in `LUx + b`. Can also be a batch of vectors. |
69 | 67 | """ |
70 | | - super().__init__(event_ndims_in=1, is_constant_jacobian=True) |
71 | | - self._batch_shape = unconstrained_affine.common_batch_shape(matrix, bias) |
72 | | - self._bias = bias |
73 | | - |
74 | | - def compute_lu(matrix): |
75 | | - # Lower-triangular matrix with ones on the diagonal. |
76 | | - lower = jnp.eye(matrix.shape[-1]) + jnp.tril(matrix, -1) |
77 | | - # Upper-triangular matrix. |
78 | | - upper = jnp.triu(matrix) |
79 | | - # Log absolute determinant. |
80 | | - logdet = jnp.sum(jnp.log(jnp.abs(jnp.diag(matrix)))) |
81 | | - return lower, upper, logdet |
82 | | - |
83 | | - compute_lu = jnp.vectorize(compute_lu, signature="(m,m)->(m,m),(m,m),()") |
84 | | - self._lower, self._upper, self._logdet = compute_lu(matrix) |
| 68 | + if matrix.ndim < 2: |
| 69 | + raise ValueError(f"`matrix` must have at least 2 dimensions, got" |
| 70 | + f" {matrix.ndim}.") |
| 71 | + dim = matrix.shape[-1] |
| 72 | + # z = Ux |
| 73 | + self._upper_linear = triangular_affine.TriangularAffine( |
| 74 | + matrix, bias=jnp.zeros((dim,)), is_lower=False) |
| 75 | + # y = Lz + b |
| 76 | + lower = jnp.eye(dim) + jnp.tril(matrix, -1) # Replace diagonal with ones. |
| 77 | + self._lower_affine = triangular_affine.TriangularAffine( |
| 78 | + lower, bias, is_lower=True) |
| 79 | + super().__init__([self._lower_affine, self._upper_linear]) |
85 | 80 |
|
  @property
  def lower(self) -> Array:
    """The lower-triangular matrix `L`, with ones on the diagonal.

    Delegates to the `TriangularAffine` sub-bijector that implements the
    `y = Lz + b` stage of the chain.
    """
    return self._lower_affine.matrix
90 | 85 |
|
  @property
  def upper(self) -> Array:
    """The upper-triangular matrix `U`.

    Delegates to the `TriangularAffine` sub-bijector that implements the
    `z = Ux` stage of the chain.
    """
    return self._upper_linear.matrix
95 | 90 |
|
  @property
  def bias(self) -> Array:
    """The shift `b` of the transformation.

    The bias is stored on the lower-triangular stage, since the chain
    applies `y = L(Ux) + b` with the shift in the outer (lower) affine.
    """
    return self._lower_affine.bias
143 | 95 |
|
144 | 96 | def same_as(self, other: base.Bijector) -> bool: |
145 | 97 | """Returns True if this bijector is guaranteed to be the same as `other`.""" |
|
0 commit comments