Skip to content

Commit 7365f38

Browse files
Add CirculantNormal distribution. (#1988)
* Add `CirculantNormal` distribution. * Fix Jacobian of `RealFastFourierTransform`. * Add transform to pack and unpack Fourier coefficients. * Implement `CirculantNormal` using transformations. * Fix typo in skip reason for `test_generated_sample_distribution`. * Add `CirculantNormal` example notebook. * Reformat `circulant_gp` example notebook. * Fix keyword argument for `PackRealFastFourierCoefficientsTransform` test. * Fix batch shape for mean, covariance_row, covariance_matrix. * Add analytic KL divergence of diagonal `Normal` from `CirculantNormal`. * Fix sign bug in `_PositiveDefiniteCirculantVector` validity check. * Expand docstring and add references for `Circulant` distribution. * Fix shape promotion for `Circulant` distribution args. * Skip goodnes of fit test for `CirculantNormal`. * Expand docstring of `CirculantNormal`. * Update example and move legend. --------- Co-authored-by: Du Phan <[email protected]>
1 parent bdb3329 commit 7365f38

File tree

12 files changed

+673
-12
lines changed

12 files changed

+673
-12
lines changed

docs/source/distributions.rst

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,14 @@ Chi2
136136
:show-inheritance:
137137
:member-order: bysource
138138

139+
CirculantNormal
140+
^^^^^^^^^^^^^^^
141+
.. autoclass:: numpyro.distributions.continuous.CirculantNormal
142+
:members:
143+
:undoc-members:
144+
:show-inheritance:
145+
:member-order: bysource
146+
139147
Dirichlet
140148
^^^^^^^^^
141149
.. autoclass:: numpyro.distributions.continuous.Dirichlet
@@ -998,6 +1006,14 @@ OrderedTransform
9981006
:show-inheritance:
9991007
:member-order: bysource
10001008

1009+
PackRealFastFourierCoefficientsTransform
1010+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
1011+
.. autoclass:: numpyro.distributions.transforms.PackRealFastFourierCoefficientsTransform
1012+
:members:
1013+
:undoc-members:
1014+
:show-inheritance:
1015+
:member-order: bysource
1016+
10011017
PermuteTransform
10021018
^^^^^^^^^^^^^^^^
10031019
.. autoclass:: numpyro.distributions.transforms.PermuteTransform

docs/source/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ NumPyro documentation
3939
tutorials/censoring
4040
tutorials/hsgp_example
4141
tutorials/other_samplers
42+
tutorials/circulant_gp
4243
tutorials/nnx_example
4344

4445
.. nbgallery::

notebooks/source/circulant_gp.ipynb

Lines changed: 319 additions & 0 deletions
Large diffs are not rendered by default.

numpyro/distributions/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
BetaProportion,
2020
Cauchy,
2121
Chi2,
22+
CirculantNormal,
2223
Dirichlet,
2324
EulerMaruyama,
2425
Exponential,
@@ -132,6 +133,7 @@
132133
"CategoricalProbs",
133134
"Cauchy",
134135
"Chi2",
136+
"CirculantNormal",
135137
"Delta",
136138
"Dirichlet",
137139
"DirichletMultinomial",

numpyro/distributions/constraints.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
"nonnegative_integer",
4848
"positive",
4949
"positive_definite",
50+
"positive_definite_circulant_vector",
5051
"positive_semidefinite",
5152
"positive_integer",
5253
"real",
@@ -642,6 +643,19 @@ def feasible_like(self, prototype):
642643
)
643644

644645

646+
class _PositiveDefiniteCirculantVector(_SingletonConstraint):
647+
event_dim = 1
648+
649+
def __call__(self, x):
650+
jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy
651+
tol = 10 * jnp.finfo(x.dtype).eps
652+
rfft = jnp.fft.rfft(x)
653+
return (jnp.abs(rfft.imag) < tol) & (rfft.real > -tol)
654+
655+
def feasible_like(self, prototype):
656+
return jnp.zeros_like(prototype).at[..., 0].set(1.0)
657+
658+
645659
class _PositiveSemiDefinite(_SingletonConstraint):
646660
event_dim = 2
647661

@@ -792,6 +806,7 @@ def tree_flatten(self):
792806
ordered_vector = _OrderedVector()
793807
positive = _Positive()
794808
positive_definite = _PositiveDefinite()
809+
positive_definite_circulant_vector = _PositiveDefiniteCirculantVector()
795810
positive_semidefinite = _PositiveSemiDefinite()
796811
positive_integer = _IntegerPositive()
797812
positive_ordered_vector = _PositiveOrderedVector()

numpyro/distributions/continuous.py

Lines changed: 145 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
import jax.nn as nn
3434
import jax.numpy as jnp
3535
import jax.random as random
36-
from jax.scipy.linalg import cho_solve, solve_triangular
36+
from jax.scipy.linalg import cho_solve, solve_triangular, toeplitz
3737
from jax.scipy.special import (
3838
betaln,
3939
digamma,
@@ -59,12 +59,15 @@
5959
CholeskyTransform,
6060
CorrMatrixCholeskyTransform,
6161
ExpTransform,
62+
PackRealFastFourierCoefficientsTransform,
6263
PowerTransform,
64+
RealFastFourierTransform,
6365
RecursiveLinearTransform,
6466
SigmoidTransform,
6567
ZeroSumTransform,
6668
)
6769
from numpyro.distributions.util import (
70+
_reshape,
6871
add_diag,
6972
assert_one_of,
7073
betainc,
@@ -3068,3 +3071,144 @@ def entropy(self) -> ArrayLike:
30683071
return jnp.broadcast_to(
30693072
0.5 + 1.5 * jnp.euler_gamma + 0.5 * jnp.log(16 * jnp.pi), self.batch_shape
30703073
) + jnp.log(self.scale)
3074+
3075+
3076+
class CirculantNormal(TransformedDistribution):
3077+
r"""
3078+
Multivariate normal distribution with covariance matrix :math:`\mathbf{C}` that is
3079+
positive-definite and circulant [1], i.e., has periodic boundary conditions. The
3080+
density of a sample :math:`\mathbf{x}\in\mathbb{R}^n` is the standard multivariate
3081+
normal density
3082+
3083+
.. math::
3084+
3085+
p\left(\mathbf{x}\mid\boldsymbol{\mu},\mathbf{C}\right) =
3086+
\frac{\left(\mathrm{det}\,\mathbf{C}\right)^{-1/2}}{\left(2\pi\right)^{n / 2}}
3087+
\exp\left(-\frac{1}{2}\left(\mathbf{x}-\boldsymbol{\mu}\right)^\intercal
3088+
\mathbf{C}^{-1}\left(\mathbf{x}-\boldsymbol{\mu}\right)\right),
3089+
3090+
where :math:`\mathrm{det}` denotes the determinant and :math:`^\intercal` the
3091+
transpose. Circulant matrices can be diagnolized efficiently using the discrete
3092+
Fourier transform [1], allowing the log likelihood to be evaluated in
3093+
:math:`n \log n` time for :math:`n` observations [2].
3094+
3095+
:param loc: Mean of the distribution :math:`\boldsymbol{\mu}`.
3096+
:param covariance_row: First row of the circulant covariance matrix
3097+
:math:`\boldsymbol{C}`. Because of periodic boundary conditions, the covariance
3098+
matrix is fully determined by its first row (see
3099+
:func:`jax.scipy.linalg.toeplitz` for further details).
3100+
:param covariance_rfft: Real part of the real fast Fourier transform of
3101+
:code:`covariance_row`, the first row of the circulant covariance matrix
3102+
:math:`\boldsymbol{C}`.
3103+
3104+
**References:**
3105+
3106+
1. Wikipedia. (n.d.). Circulant matrix. Retrieved March 6, 2025, from
3107+
https://en.wikipedia.org/wiki/Circulant_matrix
3108+
2. Wood, A. T. A., & Chan, G. (1994). Simulation of Stationary Gaussian Processes in
3109+
:math:`\left[0, 1\right]^d`. *Journal of Computational and Graphical Statistics*,
3110+
3(4), 409--432. https://doi.org/10.1080/10618600.1994.10474655
3111+
"""
3112+
3113+
arg_constraints = {
3114+
"loc": constraints.real_vector,
3115+
"covariance_row": constraints.positive_definite_circulant_vector,
3116+
"covariance_rfft": constraints.independent(constraints.positive, 1),
3117+
}
3118+
support = constraints.real_vector
3119+
3120+
def __init__(
3121+
self,
3122+
loc: jnp.ndarray,
3123+
covariance_row: jnp.ndarray = None,
3124+
covariance_rfft: jnp.ndarray = None,
3125+
*,
3126+
validate_args=None,
3127+
) -> None:
3128+
# We demand a one-dimensional input, because we cannot determine the event shape
3129+
# if only the `covariance_rfft` is given.
3130+
assert jnp.ndim(loc) > 0, "Location parameter must have at least one dimension."
3131+
n = jnp.shape(loc)[-1]
3132+
n_rfft = n // 2 + 1
3133+
assert_one_of(covariance_row=covariance_row, covariance_rfft=covariance_rfft)
3134+
3135+
if covariance_rfft is None:
3136+
# Evaluate `covariance_rfft` if not provided and validate.
3137+
assert covariance_row.shape[-1] == n
3138+
loc, covariance_row = promote_shapes(loc, covariance_row)
3139+
covariance_rfft = jnp.fft.rfft(covariance_row).real
3140+
self.covariance_row = covariance_row
3141+
else:
3142+
# The `covariance_rfft` and `loc` are not promotable because the trailing
3143+
# dimension does not match. We manually retrieve the shapes and then
3144+
# promote.
3145+
loc_shape, covariance_rfft_shape = promote_shapes(
3146+
loc[..., 0], covariance_rfft[..., 0], return_shapes=True
3147+
)
3148+
loc = _reshape(loc, loc_shape + (n,))
3149+
covariance_rfft = _reshape(
3150+
covariance_rfft, covariance_rfft_shape + (n_rfft,)
3151+
)
3152+
3153+
self.loc = loc
3154+
self.covariance_rfft = covariance_rfft
3155+
3156+
# Construct the base distribution.
3157+
n_imag = n - n_rfft
3158+
assert self.covariance_rfft.shape[-1] == n_rfft
3159+
var_rfft = (n * covariance_rfft / 2).at[..., 0].mul(2)
3160+
if n % 2 == 0:
3161+
var_rfft = var_rfft.at[..., -1].mul(2)
3162+
var_rfft = jnp.concatenate([var_rfft, var_rfft[..., 1 : 1 + n_imag]], axis=-1)
3163+
assert var_rfft.shape[-1] == n
3164+
base_distribution = Normal(scale=jnp.sqrt(var_rfft)).to_event(1)
3165+
3166+
super().__init__(
3167+
base_distribution,
3168+
[
3169+
PackRealFastFourierCoefficientsTransform((n,)),
3170+
RealFastFourierTransform((n,)).inv,
3171+
AffineTransform(loc, scale=1.0),
3172+
],
3173+
validate_args=validate_args,
3174+
)
3175+
3176+
@property
3177+
def mean(self) -> jnp.ndarray:
3178+
return jnp.broadcast_to(self.loc, self.shape())
3179+
3180+
@lazy_property
3181+
def covariance_row(self) -> jnp.ndarray:
3182+
return jnp.fft.irfft(self.covariance_rfft, n=self.event_shape[-1])
3183+
3184+
@lazy_property
3185+
def covariance_matrix(self) -> jnp.ndarray:
3186+
*leading_shape, n = self.covariance_row.shape
3187+
if leading_shape:
3188+
# `toeplitz` flattens the input, and we need to broadcast manually.
3189+
(n,) = self.event_shape
3190+
return vmap(toeplitz)(self.covariance_row.reshape((-1, n))).reshape(
3191+
(*leading_shape, n, n)
3192+
)
3193+
else:
3194+
return toeplitz(self.covariance_row)
3195+
3196+
@lazy_property
3197+
def variance(self) -> jnp.ndarray:
3198+
return jnp.broadcast_to(self.covariance_row[..., 0, None], self.shape())
3199+
3200+
@staticmethod
3201+
def infer_shapes(
3202+
loc: tuple = (), covariance_row: tuple = None, covariance_rfft: tuple = None
3203+
):
3204+
assert_one_of(covariance_row=covariance_row, covariance_rfft=covariance_rfft)
3205+
for cov in [covariance_rfft, covariance_row]:
3206+
if cov is not None:
3207+
batch_shape = jnp.broadcast_shapes(loc[:-1], cov[:-1])
3208+
event_shape = loc[-1:]
3209+
return batch_shape, event_shape
3210+
3211+
def entropy(self):
3212+
(n,) = self.event_shape
3213+
log_abs_det_jacobian = 2 * jnp.log(2) * ((n - 1) // 2) - jnp.log(n) * n
3214+
return self.base_dist.entropy() + log_abs_det_jacobian / 2

numpyro/distributions/kl.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333

3434
from numpyro.distributions.continuous import (
3535
Beta,
36+
CirculantNormal,
3637
Dirichlet,
3738
Gamma,
3839
Kumaraswamy,
@@ -183,6 +184,27 @@ def _shapes_are_broadcastable(first_shape, second_shape):
183184
return 0.5 * (tr + t1 - D - log_det_ratio)
184185

185186

187+
@dispatch(Independent, CirculantNormal)
188+
def kl_divergence(p: Independent, q: CirculantNormal):
189+
# We can only calculate the KL divergence if the base distribution is normal.
190+
if not isinstance(p.base_dist, Normal) or p.reinterpreted_batch_ndims != 1:
191+
raise NotImplementedError
192+
193+
residual = q.mean - p.mean
194+
n = residual.shape[-1]
195+
log_covariance_rfft = jnp.log(q.covariance_rfft)
196+
return (
197+
jnp.vecdot(
198+
residual, jnp.fft.irfft(jnp.fft.rfft(residual) / q.covariance_rfft, n)
199+
)
200+
+ jnp.fft.irfft(1 / q.covariance_rfft, n)[..., 0] * jnp.sum(p.variance, axis=-1)
201+
+ log_covariance_rfft.sum(axis=-1)
202+
+ log_covariance_rfft[..., 1 : (n + 1) // 2].sum(axis=-1)
203+
- jnp.log(p.variance).sum(axis=-1)
204+
- n
205+
) / 2
206+
207+
186208
@dispatch(Beta, Beta)
187209
def kl_divergence(p, q):
188210
# From https://en.wikipedia.org/wiki/Beta_distribution#Quantities_of_information_(entropy)

numpyro/distributions/transforms.py

Lines changed: 76 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
"LowerCholeskyTransform",
4141
"ScaledUnitLowerCholeskyTransform",
4242
"LowerCholeskyAffine",
43+
"PackRealFastFourierCoefficientsTransform",
4344
"PermuteTransform",
4445
"PowerTransform",
4546
"RealFastFourierTransform",
@@ -1311,10 +1312,15 @@ def inverse_shape(self, shape: tuple) -> tuple:
13111312
def log_abs_det_jacobian(
13121313
self, x: jnp.ndarray, y: jnp.ndarray, intermediates: None = None
13131314
) -> jnp.ndarray:
1314-
shape = jnp.broadcast_shapes(
1315+
batch_shape = jnp.broadcast_shapes(
13151316
x.shape[: -self.transform_ndims], y.shape[: -self.transform_ndims]
13161317
)
1317-
return jnp.zeros_like(x, shape=shape)
1318+
event_shape = x.shape[-self.transform_ndims :]
1319+
size = math.prod(event_shape)
1320+
q = math.prod(2 - size % 2 for size in event_shape)
1321+
return jnp.broadcast_to(
1322+
(size * jnp.log(size) - jnp.log(2) * (size - q)) / 2, batch_shape
1323+
)
13181324

13191325
def tree_flatten(self):
13201326
aux_data = {
@@ -1339,6 +1345,74 @@ def __eq__(self, other):
13391345
)
13401346

13411347

1348+
class PackRealFastFourierCoefficientsTransform(Transform):
1349+
"""
1350+
Transform a real vector to complex coefficients of a real fast Fourier transform.
1351+
1352+
:param transform_shape: Shape of the real vector, defaults to the input size.
1353+
"""
1354+
1355+
domain = constraints.real_vector
1356+
codomain = constraints.independent(constraints.complex, 1)
1357+
1358+
def __init__(self, transform_shape: tuple = None) -> None:
1359+
assert transform_shape is None or len(transform_shape) == 1, (
1360+
"Packing Fourier coefficients is only implemented for vectors."
1361+
)
1362+
self.shape = transform_shape
1363+
1364+
def tree_flatten(self):
1365+
return (), ((), {"shape": self.shape})
1366+
1367+
def forward_shape(self, shape: tuple) -> tuple:
1368+
*batch_shape, n = shape
1369+
assert self.shape is None or self.shape == (n,), (
1370+
f"`shape` must be `None` or `{self.shape}. Got `{shape}`."
1371+
)
1372+
n_rfft = n // 2 + 1
1373+
return (*batch_shape, n_rfft)
1374+
1375+
def inverse_shape(self, shape: tuple) -> tuple:
1376+
*batch_shape, n_rfft = shape
1377+
assert self.shape is not None, (
1378+
"Shape must be specified in `__init__` for inverse transform."
1379+
)
1380+
(n,) = self.shape
1381+
assert n_rfft == n // 2 + 1
1382+
return (*batch_shape, n)
1383+
1384+
def log_abs_det_jacobian(
1385+
self, x: jnp.ndarray, y: jnp.ndarray, intermediates: None = None
1386+
) -> jnp.ndarray:
1387+
shape = jnp.broadcast_shapes(x.shape[:-1], y.shape[:-1])
1388+
return jnp.zeros_like(x, shape=shape)
1389+
1390+
def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
1391+
assert self.shape is None or self.shape == x.shape[-1:]
1392+
n = x.shape[-1]
1393+
n_real = n // 2 + 1
1394+
n_imag = n - n_real
1395+
complex_dtype = jnp.result_type(x.dtype, jnp.complex64)
1396+
return (
1397+
x[..., :n_real]
1398+
.astype(complex_dtype)
1399+
.at[..., 1 : 1 + n_imag]
1400+
.add(1j * x[..., n_real:])
1401+
)
1402+
1403+
def _inverse(self, y: jnp.ndarray) -> jnp.ndarray:
1404+
(n,) = self.shape
1405+
n_real = n // 2 + 1
1406+
n_imag = n - n_real
1407+
return jnp.concatenate([y.real, y.imag[..., 1 : n_imag + 1]], axis=-1)
1408+
1409+
def __eq__(self, other) -> bool:
1410+
return (
1411+
isinstance(other, PackRealFastFourierCoefficientsTransform)
1412+
and self.shape == other.shape
1413+
)
1414+
1415+
13421416
class RecursiveLinearTransform(Transform):
13431417
"""
13441418
Apply a linear transformation recursively such that

0 commit comments

Comments
 (0)