Skip to content

Commit e2332e7

Browse files
committed
generate flm from spectra
1 parent 3ed665d commit e2332e7

File tree

3 files changed

+154
-0
lines changed

3 files changed

+154
-0
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ tests = [
7272
"numpy<2", # Required currently due to lack of Numpy v2 compatible pyssht release
7373
"pytest",
7474
"pytest-cov",
75+
"pytest-rerunfailures",
7576
"so3",
7677
"pyssht",
7778
]

s2fft/utils/signal_generator.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@
66
from s2fft.sampling import s2_samples as samples
77
from s2fft.sampling import so3_samples as wigner_samples
88

9+
TYPE_CHECKING = False
10+
if TYPE_CHECKING:
11+
import jax
12+
913

1014
def complex_normal(
1115
rng: np.random.Generator,
@@ -210,3 +214,60 @@ def generate_flmn(
210214
rng, len_indices, var=2
211215
)
212216
return torch.from_numpy(flmn) if using_torch else flmn
217+
218+
219+
def generate_flm_from_spectra(
220+
rng: np.random.Generator,
221+
spectra: np.ndarray | jax.Array,
222+
) -> np.ndarray | jax.Array:
223+
r"""
224+
Generate a stack of random harmonic coefficients from power spectra.
225+
226+
The input power spectra must be a stack of shape *(K, K, L)* where
227+
*K* is the number of fields to be sampled, and *L* is the harmonic
228+
band-limit.
229+
230+
Args:
231+
rng (Generator): Random number generator.
232+
233+
spectra (np.ndarray | jax.Array): Stack of angular power spectra.
234+
235+
Returns:
236+
np.ndarray | jax.Array: A stack of random spherical harmonic
237+
coefficients with the given power spectra.
238+
239+
"""
240+
# get the Array API namespace from spectra
241+
xp = spectra.__array_namespace__()
242+
243+
# check input
244+
if spectra.ndim != 3 or spectra.shape[0] != spectra.shape[1]:
245+
raise ValueError("shape of spectra must be (K, K, L)")
246+
247+
# K is the number of fields, L is the band limit
248+
*_, K, L = spectra.shape
249+
250+
# permute shape (K, K, L) -> (L, K, K)
251+
spectra = xp.permute_dims(spectra, (2, 0, 1))
252+
253+
# SVD for matrix square root
254+
# not using cholesky() here because matrix may be semi-definite
255+
# divides spectra by 2 for correct amplitude
256+
u, s, vh = xp.linalg.svd(spectra / 2, full_matrices=False)
257+
258+
# compute the matrix square root for sampling
259+
a = u @ (xp.sqrt(s[..., None]) * vh)
260+
261+
# permute shape (L, K, K) -> (K, K, L)
262+
a = xp.permute_dims(a, (1, 2, 0))
263+
264+
# sample the random coefficients
265+
# always use reality=True, this could be real fields or E/B modes
266+
# shape of flm is (K, L, M)
267+
flm = generate_flm(rng, L, reality=True, size=K)
268+
269+
# compute the matrix multiplication by hand, because we have a mix of
270+
# contraction (dim=K) and broadcasting (dim=L)
271+
flm = (a[..., None] * flm).sum(axis=-3)
272+
273+
return flm

tests/test_signal_generator.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
import jax.numpy as jnp
12
import numpy as np
23
import pytest
4+
from jax.test_util import check_grads
35

46
import s2fft
57
import s2fft.sampling as smp
@@ -143,3 +145,93 @@ def test_generate_flmn(rng, L, N, L_lower, reality):
143145
assert np.allclose(f_complex.imag, 0)
144146
f_real = s2fft.wigner.inverse(flmn, L, N, reality=True, L_lower=L_lower)
145147
assert np.allclose(f_complex.real, f_real)
148+
149+
150+
def gaussian_covariance(spectra):
151+
"""Gaussian covariance for a stack of spectra.
152+
153+
If the shape of *spectra* is *(K, K, L)*, the shape of the
154+
covariance is *(L, C, C)*, where ``C = K * (K + 1) // 2``
155+
is the number of independent spectra.
156+
157+
"""
158+
_, K, L = spectra.shape
159+
row, col = np.tril_indices(K)
160+
cov = np.zeros((L, row.size, col.size))
161+
ell = np.arange(L)
162+
for i, j in np.ndindex(row.size, col.size):
163+
cov[:, i, j] = (
164+
spectra[row[i], row[j]] * spectra[col[i], col[j]]
165+
+ spectra[row[i], col[j]] * spectra[col[i], row[j]]
166+
) / (2 * ell + 1)
167+
return cov
168+
169+
170+
@pytest.mark.flaky
171+
@pytest.mark.parametrize("L", L_values_to_test)
172+
@pytest.mark.parametrize("xp", [np, jnp])
173+
def test_generate_flm_from_spectra(rng, L, xp):
174+
# number of fields to generate
175+
K = 4
176+
177+
# correlation matrix for fields, applied to all ell
178+
corr = xp.asarray(
179+
[
180+
[1.0, 0.1, -0.1, 0.1],
181+
[0.1, 1.0, 0.1, -0.1],
182+
[-0.1, 0.1, 1.0, 0.1],
183+
[0.1, -0.1, 0.1, 1.0],
184+
],
185+
)
186+
187+
ell = xp.arange(L)
188+
189+
# auto-spectra are power laws
190+
powers = xp.arange(1, K + 1)
191+
auto = 1 / (2 * ell + 1) ** powers[:, None]
192+
193+
# compute the spectra from auto and corr
194+
spectra = xp.sqrt(auto[:, None, :] * auto[None, :, :]) * corr[:, :, None]
195+
assert spectra.shape == (K, K, L)
196+
197+
# generate random flm from spectra
198+
flm = s2fft.utils.signal_generator.generate_flm_from_spectra(rng, spectra)
199+
assert flm.shape == (K, L, 2 * L - 1)
200+
201+
# compute the realised spectra
202+
re, im = flm.real, flm.imag
203+
result = (
204+
re[None, :, :, :] * re[:, None, :, :] + im[None, :, :, :] * im[:, None, :, :]
205+
)
206+
result = result.sum(axis=-1) / (2 * ell + 1)
207+
208+
# compute covariance of sampled spectra
209+
cov = gaussian_covariance(spectra)
210+
211+
# data vector, remove duplicate entries, and put L dim first
212+
x = result - spectra
213+
x = x[np.tril_indices(K)]
214+
x = x.T
215+
216+
# compute chi2/n of realised spectra
217+
y = xp.linalg.solve(cov, x[..., None])[..., 0]
218+
n = x.size
219+
chi2_n = (x * y).sum() / n
220+
221+
# make sure chi2/n is as expected
222+
sigma = np.sqrt(2 / n)
223+
assert np.fabs(chi2_n - 1.0) < 3 * sigma
224+
225+
226+
@pytest.mark.parametrize("L", L_values_to_test)
227+
def test_generate_flm_from_spectra_grads(L):
228+
# fixed set of power spectra
229+
ell = jnp.arange(L)
230+
cl = 1 / (2 * ell + 1)
231+
spectra = cl.reshape(1, 1, L)
232+
233+
def func(x):
234+
rng = np.random.default_rng(42)
235+
return s2fft.utils.signal_generator.generate_flm_from_spectra(rng, x)
236+
237+
check_grads(func, (spectra,), 1)

0 commit comments

Comments
 (0)