Skip to content

Commit 3ed665d

Browse files
committed
add size parameter to generate_flm()
1 parent 1d5fa15 commit 3ed665d

File tree

2 files changed

+44
-7
lines changed

2 files changed

+44
-7
lines changed

s2fft/utils/signal_generator.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ def generate_flm(
7474
spin: int = 0,
7575
reality: bool = False,
7676
using_torch: bool = False,
77+
size: tuple[int, ...] | int | None = None,
7778
) -> np.ndarray | torch.Tensor:
7879
r"""
7980
Generate a 2D set of random harmonic coefficients.
@@ -94,29 +95,39 @@ def generate_flm(
9495
9596
using_torch (bool, optional): Desired frontend functionality. Defaults to False.
9697
98+
size (tuple[int, ...] | int | None, optional): Shape of realisations.
99+
97100
Returns:
98101
np.ndarray: Random set of spherical harmonic coefficients.
99102
100103
"""
101-
flm = np.zeros(samples.flm_shape(L), dtype=np.complex128)
104+
# always turn size into a tuple of int
105+
if size is None:
106+
size = ()
107+
elif isinstance(size, int):
108+
size = (size,)
109+
elif not (isinstance(size, tuple) and all(isinstance(_, int) for _ in size)):
110+
raise TypeError("size must be int or tuple of int")
111+
112+
flm = np.zeros((*size, *samples.flm_shape(L)), dtype=np.complex128)
102113
min_el = max(L_lower, abs(spin))
103114
# m = 0 coefficients are always real
104-
flm[min_el:L, L - 1] = rng.standard_normal(L - min_el)
115+
flm[..., min_el:L, L - 1] = rng.standard_normal((*size, L - min_el))
105116
# Construct arrays of m and el indices for entries in flm corresponding to complex-
106117
# valued coefficients (m > 0)
107118
el_indices, m_indices = complex_el_and_m_indices(L, min_el)
108-
len_indices = len(m_indices)
119+
rand_size = (*size, len(m_indices))
109120
# Generate independent complex coefficients for positive m
110-
flm[el_indices, L - 1 + m_indices] = complex_normal(rng, len_indices, var=2)
121+
flm[..., el_indices, L - 1 + m_indices] = complex_normal(rng, rand_size, var=2)
111122
if reality:
112123
# Real-valued signal so set complex coefficients for negative m using conjugate
113124
# symmetry such that flm[el, L - 1 - m] = (-1)**m * flm[el, L - 1 + m].conj
114-
flm[el_indices, L - 1 - m_indices] = (-1) ** m_indices * (
115-
flm[el_indices, L - 1 + m_indices].conj()
125+
flm[..., el_indices, L - 1 - m_indices] = (-1) ** m_indices * (
126+
flm[..., el_indices, L - 1 + m_indices].conj()
116127
)
117128
else:
118129
# Non-real signal so generate independent complex coefficients for negative m
119-
flm[el_indices, L - 1 - m_indices] = complex_normal(rng, len_indices, var=2)
130+
flm[..., el_indices, L - 1 - m_indices] = complex_normal(rng, rand_size, var=2)
120131
return torch.from_numpy(flm) if using_torch else flm
121132

122133

tests/test_signal_generator.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,14 @@ def check_flm_conjugate_symmetry(flm, L, min_el):
5555
assert flm[el, L - 1 - m] == (-1) ** m * flm[el, L - 1 + m].conj()
5656

5757

58+
def check_flm_unequal(flm1, flm2, L, min_el):
59+
"""assert that two passed flm are elementwise unequal"""
60+
for el in range(L):
61+
for m in range(L):
62+
if not (el < min_el or m > el):
63+
assert flm1[el, L - 1 + m] != flm2[el, L - 1 - m]
64+
65+
5866
@pytest.mark.parametrize("L", L_values_to_test)
5967
@pytest.mark.parametrize("L_lower", L_lower_to_test)
6068
@pytest.mark.parametrize("spin", spin_to_test)
@@ -76,6 +84,24 @@ def test_generate_flm(rng, L, L_lower, spin, reality):
7684
assert np.allclose(f_complex.real, f_real)
7785

7886

87+
@pytest.mark.parametrize("L", L_values_to_test)
88+
@pytest.mark.parametrize("L_lower", L_lower_to_test)
89+
@pytest.mark.parametrize("spin", spin_to_test)
90+
@pytest.mark.parametrize("reality", reality_values_to_test)
91+
def test_generate_flm_size(rng, L, L_lower, spin, reality):
92+
if reality and spin != 0:
93+
pytest.skip("Reality only valid for scalar fields (spin=0).")
94+
95+
flm = gen.generate_flm(rng, L, L_lower, spin, reality, size=2)
96+
assert flm.shape == (2,) + smp.s2_samples.flm_shape(L)
97+
check_flm_zeros(flm[0], L, max(L_lower, abs(spin)))
98+
check_flm_zeros(flm[1], L, max(L_lower, abs(spin)))
99+
check_flm_unequal(flm[0], flm[1], L, max(L_lower, abs(spin)))
100+
101+
flm = gen.generate_flm(rng, L, L_lower, spin, reality, size=(3, 4))
102+
assert flm.shape == (3, 4) + smp.s2_samples.flm_shape(L)
103+
104+
79105
def check_flmn_zeros(flmn, L, N, L_lower):
80106
for n in range(-N + 1, N):
81107
min_el = max(L_lower, abs(n))

0 commit comments

Comments
 (0)