Skip to content

Commit f08c6bc

Browse files
authored
Merge pull request #252 from astro-informatics/mmg/vectorized-signal-generators
Vectorize signal generator functions
2 parents 3ff7699 + 17055fb commit f08c6bc

File tree

4 files changed

+249
-38
lines changed

4 files changed

+249
-38
lines changed

s2fft/transforms/wigner.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,9 @@ def inverse_numpy(
156156
)
157157
fban = np.zeros(samples.f_shape(L, N, sampling, nside), dtype=np.complex128)
158158

159+
# Copy flmn argument to avoid in-place updates being propagated back to caller
160+
flmn = flmn.copy()
161+
159162
flmn[:, L_lower:] = np.einsum(
160163
"...nlm,...l->...nlm",
161164
flmn[:, L_lower:],

s2fft/utils/signal_generator.py

Lines changed: 126 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,80 @@
1+
from __future__ import annotations
2+
13
import numpy as np
24
import torch
35

46
from s2fft.sampling import s2_samples as samples
57
from s2fft.sampling import so3_samples as wigner_samples
68

79

10+
def complex_normal(
11+
rng: np.random.Generator,
12+
size: int | tuple[int],
13+
var: float,
14+
) -> np.ndarray:
15+
"""
16+
Generate array of samples from zero-mean complex normal distribution.
17+
18+
For `z ~ ComplexNormal(0, var)` we have that `imag(z) ~ Normal(0, var/2)` and
19+
`real(z) ~ Normal(0, var/2)` where `Normal(μ, σ²)` is the (real-valued) normal
20+
distribution with mean parameter `μ` and variance parameter `σ²`.
21+
22+
Args:
23+
rng: Numpy random generator object to generate samples using.
24+
size: Output shape of array to generate.
25+
var: Variance of complex normal distribution to generate samples from.
26+
27+
Returns:
28+
Complex-valued array of shape `size` contained generated samples.
29+
30+
"""
31+
return (rng.standard_normal(size) + 1j * rng.standard_normal(size)) * (
32+
var / 2
33+
) ** 0.5
34+
35+
36+
def complex_el_and_m_indices(L: int, min_el: int) -> tuple[np.ndarray, np.ndarray]:
37+
"""
38+
Generate pairs of el, m indices for accessing complex harmonic coefficients.
39+
40+
Equivalent to nested list-comprehension based implementation
41+
42+
```
43+
el_indices, m_indices = np.array(
44+
[(el, m) for el in range(min_el, L) for m in range(1, el + 1)]
45+
).T
46+
```
47+
48+
For `L, min_el = 1024, 0`, this implementation is around 80x quicker in
49+
benchmarks compared to list-comprehension implementation.
50+
51+
Args:
52+
L: Harmonic band-limit.
53+
min_el: Inclusive lower-bound for el indices.
54+
55+
Returns:
56+
Tuple `(el_indices, m_indices)` with both entries 1D integer-valued NumPy arrays
57+
of same size, with values of corresponding entries corresponding to pairs of
58+
el and m indices.
59+
60+
"""
61+
el_indices, m_indices = np.tril_indices(m=L, k=-1, n=L)
62+
m_indices += 1
63+
if min_el > 0:
64+
in_range_el = el_indices >= min_el
65+
el_indices = el_indices[in_range_el]
66+
m_indices = m_indices[in_range_el]
67+
return el_indices, m_indices
68+
69+
870
def generate_flm(
971
rng: np.random.Generator,
1072
L: int,
1173
L_lower: int = 0,
1274
spin: int = 0,
1375
reality: bool = False,
1476
using_torch: bool = False,
15-
) -> np.ndarray:
77+
) -> np.ndarray | torch.Tensor:
1678
r"""
1779
Generate a 2D set of random harmonic coefficients.
1880
@@ -37,20 +99,24 @@ def generate_flm(
3799
38100
"""
39101
flm = np.zeros(samples.flm_shape(L), dtype=np.complex128)
40-
41-
for el in range(max(L_lower, abs(spin)), L):
42-
if reality:
43-
flm[el, 0 + L - 1] = rng.normal()
44-
else:
45-
flm[el, 0 + L - 1] = rng.normal() + 1j * rng.normal()
46-
47-
for m in range(1, el + 1):
48-
flm[el, m + L - 1] = rng.normal() + 1j * rng.normal()
49-
if reality:
50-
flm[el, -m + L - 1] = (-1) ** m * np.conj(flm[el, m + L - 1])
51-
else:
52-
flm[el, -m + L - 1] = rng.normal() + 1j * rng.normal()
53-
102+
min_el = max(L_lower, abs(spin))
103+
# m = 0 coefficients are always real
104+
flm[min_el:L, L - 1] = rng.standard_normal(L - min_el)
105+
# Construct arrays of m and el indices for entries in flm corresponding to complex-
106+
# valued coefficients (m > 0)
107+
el_indices, m_indices = complex_el_and_m_indices(L, min_el)
108+
len_indices = len(m_indices)
109+
# Generate independent complex coefficients for positive m
110+
flm[el_indices, L - 1 + m_indices] = complex_normal(rng, len_indices, var=2)
111+
if reality:
112+
# Real-valued signal so set complex coefficients for negative m using conjugate
113+
# symmetry such that flm[el, L - 1 - m] = (-1)**m * flm[el, L - 1 + m].conj
114+
flm[el_indices, L - 1 - m_indices] = (-1) ** m_indices * (
115+
flm[el_indices, L - 1 + m_indices].conj()
116+
)
117+
else:
118+
# Non-real signal so generate independent complex coefficients for negative m
119+
flm[el_indices, L - 1 - m_indices] = complex_normal(rng, len_indices, var=2)
54120
return torch.from_numpy(flm) if using_torch else flm
55121

56122

@@ -61,7 +127,7 @@ def generate_flmn(
61127
L_lower: int = 0,
62128
reality: bool = False,
63129
using_torch: bool = False,
64-
) -> np.ndarray:
130+
) -> np.ndarray | torch.Tensor:
65131
r"""
66132
Generate a 3D set of random Wigner coefficients.
67133
@@ -87,26 +153,49 @@ def generate_flmn(
87153
88154
"""
89155
flmn = np.zeros(wigner_samples.flmn_shape(L, N), dtype=np.complex128)
90-
91156
for n in range(-N + 1, N):
92-
for el in range(max(L_lower, abs(n)), L):
93-
if reality:
94-
flmn[N - 1 + n, el, 0 + L - 1] = rng.normal()
95-
flmn[N - 1 - n, el, 0 + L - 1] = (-1) ** n * flmn[
96-
N - 1 + n,
97-
el,
98-
0 + L - 1,
99-
]
100-
else:
101-
flmn[N - 1 + n, el, 0 + L - 1] = rng.normal() + 1j * rng.normal()
102-
103-
for m in range(1, el + 1):
104-
flmn[N - 1 + n, el, m + L - 1] = rng.normal() + 1j * rng.normal()
105-
if reality:
106-
flmn[N - 1 - n, el, -m + L - 1] = (-1) ** (m + n) * np.conj(
107-
flmn[N - 1 + n, el, m + L - 1]
108-
)
109-
else:
110-
flmn[N - 1 + n, el, -m + L - 1] = rng.normal() + 1j * rng.normal()
111-
157+
min_el = max(L_lower, abs(n))
158+
# Separately deal with m = 0 case
159+
if reality:
160+
if n == 0:
161+
# For m = n = 0
162+
# flmn[N - 1, el, L - 1] = flmn[N - 1, el, L - 1].conj (real-valued)
163+
# Generate independent real coefficients for n = 0
164+
flmn[N - 1, min_el:L, L - 1] = rng.standard_normal(L - min_el)
165+
elif n > 0:
166+
# Generate independent complex coefficients for positive n
167+
flmn[N - 1 + n, min_el:L, L - 1] = complex_normal(
168+
rng, L - min_el, var=2
169+
)
170+
# For m = 0, n > 0
171+
# flmn[N - 1 - n, el, L - 1] = (-1)**n * flmn[N - 1 + n, el, L - 1].conj
172+
flmn[N - 1 - n, min_el:L, L - 1] = (-1) ** n * (
173+
flmn[N - 1 + n, min_el:L, L - 1].conj()
174+
)
175+
else:
176+
flmn[N - 1 + n, min_el:L, L - 1] = complex_normal(rng, L - min_el, var=2)
177+
# Construct arrays of m and el indices for entries in flmn slices for n
178+
# corresponding to complex-valued coefficients (m > 0)
179+
el_indices, m_indices = complex_el_and_m_indices(L, min_el)
180+
len_indices = len(m_indices)
181+
# Generate independent complex coefficients for positive m
182+
flmn[N - 1 + n, el_indices, L - 1 + m_indices] = complex_normal(
183+
rng, len_indices, var=2
184+
)
185+
if reality:
186+
# Real-valued signal so set complex coefficients for negative m using
187+
# conjugate symmetry relationship
188+
# flmn[N - 1 - n, el, L - 1 - m] =
189+
# (-1)**(m + n) * flmn[N - 1 + n, el, L - 1 + m].conj
190+
# As (m_indices + n) can be negative use floating point value (-1.0) as
191+
# base of exponentation operation to avoid Numpy
192+
# 'ValueError: Integers to negative integer powers are not allowed' error
193+
flmn[N - 1 - n, el_indices, L - 1 - m_indices] = (-1.0) ** (
194+
m_indices + n
195+
) * flmn[N - 1 + n, el_indices, L - 1 + m_indices].conj()
196+
else:
197+
# Complex signal so generate independent complex coefficients for negative m
198+
flmn[N - 1 + n, el_indices, L - 1 - m_indices] = complex_normal(
199+
rng, len_indices, var=2
200+
)
112201
return torch.from_numpy(flmn) if using_torch else flmn

tests/test_signal_generator.py

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
import numpy as np
2+
import pytest
3+
4+
import s2fft
5+
import s2fft.sampling as smp
6+
import s2fft.utils.signal_generator as gen
7+
8+
L_values_to_test = [6, 7, 16]
9+
L_lower_to_test = [0, 1]
10+
spin_to_test = [-2, 0, 1]
11+
reality_values_to_test = [False, True]
12+
13+
14+
@pytest.mark.parametrize("size", (10, 100, 1000))
15+
@pytest.mark.parametrize("var", (1, 2))
16+
def test_complex_normal(rng, size, var):
17+
samples = gen.complex_normal(rng, size, var)
18+
assert samples.dtype == np.complex128
19+
assert samples.size == size
20+
mean = samples.mean()
21+
# Error in real + imag components of mean estimate ~ Normal(0, (var / 2) / size)
22+
# Therefore difference between mean estimate and true zero value should be
23+
# less than 3 * sqrt(var / (2 * size)) with probability 0.997
24+
mean_error_tol = 3 * (var / (2 * size)) ** 0.5
25+
assert abs(mean.imag) < mean_error_tol and abs(mean.real) < mean_error_tol
26+
# If S is (unbiased) sample variance estimate then (size - 1) * S / var is a
27+
# chi-squared distributed random variable with (size - 1) degrees of freedom
28+
# For size >> 1, S ~approx Normal(var, 2 * var**2 / (size - 1)) so error in
29+
# variance estimate should be less than 3 * sqrt(2 * var**2 / (size - 1))
30+
# with high probability
31+
assert abs(samples.var(ddof=1) - var) < 3 * (2 * var**2 / (size - 1)) ** 0.5
32+
33+
34+
@pytest.mark.parametrize("L", L_values_to_test)
35+
@pytest.mark.parametrize("min_el", [0, 1])
36+
def test_complex_el_and_m_indices(L, min_el):
37+
expected_el_indices, expected_m_indices = np.array(
38+
[(el, m) for el in range(min_el, L) for m in range(1, el + 1)]
39+
).T
40+
el_indices, m_indices = gen.complex_el_and_m_indices(L, min_el)
41+
assert (el_indices == expected_el_indices).all()
42+
assert (m_indices == expected_m_indices).all()
43+
44+
45+
def check_flm_zeros(flm, L, min_el):
46+
for el in range(L):
47+
for m in range(L):
48+
if el < min_el or m > el:
49+
assert flm[el, L - 1 + m] == flm[el, L - 1 - m] == 0
50+
51+
52+
def check_flm_conjugate_symmetry(flm, L, min_el):
53+
for el in range(min_el, L):
54+
for m in range(el + 1):
55+
assert flm[el, L - 1 - m] == (-1) ** m * flm[el, L - 1 + m].conj()
56+
57+
58+
@pytest.mark.parametrize("L", L_values_to_test)
59+
@pytest.mark.parametrize("L_lower", L_lower_to_test)
60+
@pytest.mark.parametrize("spin", spin_to_test)
61+
@pytest.mark.parametrize("reality", reality_values_to_test)
62+
@pytest.mark.filterwarnings("ignore::RuntimeWarning")
63+
def test_generate_flm(rng, L, L_lower, spin, reality):
64+
if reality and spin != 0:
65+
pytest.skip("Reality only valid for scalar fields (spin=0).")
66+
flm = gen.generate_flm(rng, L, L_lower, spin, reality)
67+
assert flm.shape == smp.s2_samples.flm_shape(L)
68+
assert flm.dtype == np.complex128
69+
assert np.isfinite(flm).all()
70+
check_flm_zeros(flm, L, max(L_lower, abs(spin)))
71+
if reality:
72+
check_flm_conjugate_symmetry(flm, L, max(L_lower, abs(spin)))
73+
f_complex = s2fft.inverse(flm, L, spin=spin, reality=False, L_lower=L_lower)
74+
assert np.allclose(f_complex.imag, 0)
75+
f_real = s2fft.inverse(flm, L, spin=spin, reality=True, L_lower=L_lower)
76+
assert np.allclose(f_complex.real, f_real)
77+
78+
79+
def check_flmn_zeros(flmn, L, N, L_lower):
80+
for n in range(-N + 1, N):
81+
min_el = max(L_lower, abs(n))
82+
for el in range(L):
83+
for m in range(L):
84+
if el < min_el or m > el:
85+
assert (
86+
flmn[N - 1 + n, el, L - 1 + m]
87+
== flmn[N - 1 + n, el, L - 1 - m]
88+
== 0
89+
)
90+
91+
92+
def check_flmn_conjugate_symmetry(flmn, L, N, L_lower):
93+
for n in range(-N + 1, N):
94+
min_el = max(L_lower, abs(n))
95+
for el in range(min_el, L):
96+
for m in range(el + 1):
97+
assert (
98+
flmn[N - 1 - n, el, L - 1 - m]
99+
== (-1) ** (m + n) * flmn[N - 1 + n, el, L - 1 + m].conj()
100+
)
101+
102+
103+
@pytest.mark.parametrize("L", L_values_to_test)
104+
@pytest.mark.parametrize("N", [1, 2, 3])
105+
@pytest.mark.parametrize("L_lower", L_lower_to_test)
106+
@pytest.mark.parametrize("reality", reality_values_to_test)
107+
@pytest.mark.filterwarnings("ignore::RuntimeWarning")
108+
def test_generate_flmn(rng, L, N, L_lower, reality):
109+
flmn = gen.generate_flmn(rng, L, N, L_lower, reality)
110+
assert flmn.shape == smp.so3_samples.flmn_shape(L, N)
111+
assert flmn.dtype == np.complex128
112+
assert np.isfinite(flmn).all()
113+
check_flmn_zeros(flmn, L, N, L_lower)
114+
if reality:
115+
check_flmn_conjugate_symmetry(flmn, L, N, L_lower)
116+
f_complex = s2fft.wigner.inverse(flmn, L, N, reality=False, L_lower=L_lower)
117+
assert np.allclose(f_complex.imag, 0)
118+
f_real = s2fft.wigner.inverse(flmn, L, N, reality=True, L_lower=L_lower)
119+
assert np.allclose(f_complex.real, f_real)

tests/test_spherical_precompute.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,7 @@ def test_transform_inverse_high_spin(
306306
kernel = c.spin_spherical_kernel(L, spin, reality, sampling, forward=False)
307307

308308
f = inverse(flm, L, spin, kernel, sampling, reality, "numpy")
309-
tol = 1e-8 if sampling.lower() in ["dh", "gl"] else 1e-12
309+
tol = 1e-8 if sampling.lower() in ["dh", "gl"] else 1e-11
310310
np.testing.assert_allclose(f, f_check, atol=tol, rtol=tol)
311311

312312

0 commit comments

Comments
 (0)