Skip to content

Commit 07c8c31

Browse files
committed
Make coefficients real for m=0 and factor out functions
1 parent cc75c46 commit 07c8c31

File tree

1 file changed

+75
-15
lines changed

1 file changed

+75
-15
lines changed

s2fft/utils/signal_generator.py

Lines changed: 75 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,66 @@
77
from s2fft.sampling import so3_samples as wigner_samples
88

99

10+
def complex_normal(
11+
rng: np.random.Generator,
12+
size: int | tuple[int],
13+
var: float,
14+
) -> np.ndarray:
15+
"""
16+
Generate array of samples from zero-mean complex normal distribution.
17+
18+
For `z ~ ComplexNormal(0, var)` we have that `imag(z) ~ Normal(0, var/2)` and
19+
`real(z) ~ Normal(0, var/2)` where `Normal(μ, σ²)` is the (real-valued) normal
20+
distribution with mean parameter `μ` and variance parameter `σ²`.
21+
22+
Args:
23+
rng: Numpy random generator object to generate samples using.
24+
size: Output shape of array to generate.
25+
var: Variance of complex normal distribution to generate samples from.
26+
27+
Returns:
28+
Complex-valued array of shape `size` contained generated samples.
29+
30+
"""
31+
return (rng.standard_normal(size) + 1j * rng.standard_normal(size)) * (
32+
var / 2
33+
) ** 0.5
34+
35+
36+
def complex_el_and_m_indices(L: int, min_el: int) -> tuple[np.ndarray, np.ndarray]:
37+
"""
38+
Generate pairs of el, m indices for accessing complex harmonic coefficients.
39+
40+
Equivalent to nested list-comprehension based implementation
41+
42+
```
43+
el_indices, m_indices = np.array(
44+
[(el, m) for el in range(min_el, L) for m in range(1, el + 1))]
45+
)
46+
```
47+
48+
For `L, min_el = 1024, 0`, this implementation is around 80x quicker in
49+
benchmarks compared to list-comprehension implementation.
50+
51+
Args:
52+
L: Harmonic band-limit.
53+
min_el: Inclusive lower-bound for el indices.
54+
55+
Returns:
56+
Tuple `(el_indices, m_indices)` with both entries 1D integer-valued NumPy arrays
57+
of same size, with values of corresponding entries corresponding to pairs of
58+
el and m indices.
59+
60+
"""
61+
el_indices, m_indices = np.tril_indices(m=L, k=-1, n=L)
62+
m_indices += 1
63+
if min_el > 0:
64+
in_range_el = el_indices >= min_el
65+
el_indices = el_indices[in_range_el]
66+
m_indices = m_indices[in_range_el]
67+
return el_indices, m_indices
68+
69+
1070
def generate_flm(
1171
rng: np.random.Generator,
1272
L: int,
@@ -39,24 +99,24 @@ def generate_flm(
3999
40100
"""
41101
flm = np.zeros(samples.flm_shape(L), dtype=np.complex128)
42-
43102
min_el = max(L_lower, abs(spin))
103+
# m = 0 coefficients are always real
44104
flm[min_el:L, L - 1] = rng.standard_normal(L - min_el)
45-
if not reality:
46-
flm[min_el:L, L - 1] += 1j * rng.standard_normal(L - min_el)
47-
m_indices, el_indices = np.triu_indices(n=L, k=1, m=L) + np.array([[1], [0]])
48-
if min_el > 0:
49-
in_range_el = el_indices >= min_el
50-
m_indices = m_indices[in_range_el]
51-
el_indices = el_indices[in_range_el]
105+
# Construct arrays of m and el indices for entries in flm corresponding to complex-
106+
# valued coefficients (m > 0)
107+
el_indices, m_indices = complex_el_and_m_indices(L, min_el)
52108
len_indices = len(m_indices)
53-
flm[el_indices, L - 1 - m_indices] = rng.standard_normal(
54-
len_indices
55-
) + 1j * rng.standard_normal(len_indices)
56-
flm[el_indices, L - 1 + m_indices] = (-1) ** m_indices * np.conj(
57-
flm[el_indices, L - 1 - m_indices]
58-
)
59-
109+
# Generate independent complex coefficients for positive m
110+
flm[el_indices, L - 1 + m_indices] = complex_normal(rng, len_indices, var=2)
111+
if reality:
112+
# Real-valued signal so set complex coefficients for negative m using conjugate
113+
# symmetry such that flm[el, L - 1 - m] = (-1)**m * flm[el, L - 1 + m].conj
114+
flm[el_indices, L - 1 - m_indices] = (-1) ** m_indices * (
115+
flm[el_indices, L - 1 + m_indices].conj()
116+
)
117+
else:
118+
# Non-real signal so generate independent complex coefficients for negative m
119+
flm[el_indices, L - 1 - m_indices] = complex_normal(rng, len_indices, var=2)
60120
return torch.from_numpy(flm) if using_torch else flm
61121

62122

0 commit comments

Comments
 (0)