Skip to content

Commit c764a98

Browse files
committed
add on-the-fly support for Fourier Wigner transforms
1 parent 876e090 commit c764a98

File tree

2 files changed

+160
-64
lines changed

2 files changed

+160
-64
lines changed

s2fft/precompute_transforms/fourier_wigner.py

Lines changed: 137 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,15 @@
44
import numpy as np
55
from jax import jit
66

7+
from s2fft import recursions
8+
from s2fft.utils import quadrature, quadrature_jax
9+
710

811
def inverse_transform(
912
flmn: np.ndarray,
10-
DW: np.ndarray,
1113
L: int,
1214
N: int,
15+
DW: np.ndarray = None,
1316
reality: bool = False,
1417
sampling: str = "mw",
1518
) -> np.ndarray:
@@ -18,10 +21,11 @@ def inverse_transform(
1821
1922
Args:
2023
flmn (np.ndarray): Wigner coefficients.
21-
DW (Tuple[np.ndarray, np.ndarray]): Fourier coefficients of the reduced
22-
Wigner d-functions and the corresponding upsampled quadrature weights.
2324
L (int): Harmonic band-limit.
2425
N (int): Azimuthal band-limit.
26+
DW (Tuple[np.ndarray, np.ndarray], optional): Fourier coefficients of the reduced
27+
Wigner d-functions and the corresponding upsampled quadrature weights.
28+
Defaults to None.
2529
reality (bool, optional): Whether the signal on the sphere is real. If so,
2630
conjugate symmetry is exploited to reduce computational costs.
2731
Defaults to False.
@@ -37,9 +41,6 @@ def inverse_transform(
3741
f"Fourier-Wigner algorithm does not support {sampling} sampling."
3842
)
3943

40-
# EXTRACT VARIOUS PRECOMPUTES
41-
Delta, _ = DW
42-
4344
# INDEX VALUES
4445
n_start_ind = N - 1 if reality else 0
4546
n_dim = N if reality else 2 * N - 1
@@ -54,13 +55,27 @@ def inverse_transform(
5455

5556
# Calculate fmna = i^(n-m)\sum_L Delta^l_am Delta^l_an f^l_mn(2l+1)/(8pi^2)
5657
x = np.zeros((xnlm_size, xnlm_size, n_dim), dtype=flmn.dtype)
57-
x[m_offset:, m_offset:] = np.einsum(
58-
"nlm,lam,lan,l->amn",
59-
flmn[n_start_ind:],
60-
Delta,
61-
Delta[:, :, L - 1 + n],
62-
(2 * np.arange(L) + 1) / (8 * np.pi**2),
63-
)
58+
flmn = np.einsum("nlm,l->nlm", flmn, (2 * np.arange(L) + 1) / (8 * np.pi**2))
59+
60+
# PRECOMPUTE TRANSFORM
61+
if DW is not None:
62+
Delta, _ = DW
63+
x = np.zeros((xnlm_size, xnlm_size, n_dim), dtype=flmn.dtype)
64+
x[m_offset:, m_offset:] = np.einsum(
65+
"nlm,lam,lan->amn", flmn[n_start_ind:], Delta, Delta[:, :, L - 1 + n]
66+
)
67+
68+
# OTF TRANSFORM
69+
else:
70+
Delta_el = np.zeros((2 * L - 1, 2 * L - 1), dtype=np.float64)
71+
for el in range(L):
72+
Delta_el = recursions.risbo.compute_full(Delta_el, np.pi / 2, L, el)
73+
x[m_offset:, m_offset:] += np.einsum(
74+
"nm,am,an->amn",
75+
flmn[n_start_ind:, el],
76+
Delta_el,
77+
Delta_el[:, L - 1 + n],
78+
)
6479

6580
# APPLY SIGN FUNCTION AND PHASE SHIFT
6681
x = np.einsum("amn,m,n,a->nam", x, 1j ** (-m), 1j ** (n), np.exp(1j * m * theta0))
@@ -77,12 +92,12 @@ def inverse_transform(
7792
return np.fft.ifft2(x, axes=(0, 2), norm="forward")
7893

7994

80-
@partial(jit, static_argnums=(2, 3, 4, 5))
95+
@partial(jit, static_argnums=(1, 2, 4, 5))
8196
def inverse_transform_jax(
8297
flmn: jnp.ndarray,
83-
DW: jnp.ndarray,
8498
L: int,
8599
N: int,
100+
DW: jnp.ndarray = None,
86101
reality: bool = False,
87102
sampling: str = "mw",
88103
) -> jnp.ndarray:
@@ -91,10 +106,11 @@ def inverse_transform_jax(
91106
92107
Args:
93108
flmn (jnp.ndarray): Wigner coefficients.
94-
DW (Tuple[np.ndarray, np.ndarray]): Fourier coefficients of the reduced
95-
Wigner d-functions and the corresponding upsampled quadrature weights.
96109
L (int): Harmonic band-limit.
97110
N (int): Azimuthal band-limit.
111+
DW (Tuple[np.ndarray, np.ndarray], optional): Fourier coefficients of the reduced
112+
Wigner d-functions and the corresponding upsampled quadrature weights.
113+
Defaults to None.
98114
reality (bool, optional): Whether the signal on the sphere is real. If so,
99115
conjugate symmetry is exploited to reduce computational costs.
100116
Defaults to False.
@@ -110,9 +126,6 @@ def inverse_transform_jax(
110126
f"Fourier-Wigner algorithm does not support {sampling} sampling."
111127
)
112128

113-
# EXTRACT VARIOUS PRECOMPUTES
114-
Delta, _ = DW
115-
116129
# INDEX VALUES
117130
n_start_ind = N - 1 if reality else 0
118131
n_dim = N if reality else 2 * N - 1
@@ -128,11 +141,29 @@ def inverse_transform_jax(
128141
# Calculate fmna = i^(n-m)\sum_L Delta^l_am Delta^l_an f^l_mn(2l+1)/(8pi^2)
129142
x = jnp.zeros((xnlm_size, xnlm_size, n_dim), dtype=jnp.complex128)
130143
flmn = jnp.einsum("nlm,l->nlm", flmn, (2 * jnp.arange(L) + 1) / (8 * jnp.pi**2))
131-
x = x.at[m_offset:, m_offset:].set(
132-
jnp.einsum(
133-
"nlm,lam,lan->amn", flmn[n_start_ind:], Delta, Delta[:, :, L - 1 + n]
144+
145+
# PRECOMPUTE TRANSFORM
146+
if DW is not None:
147+
Delta, _ = DW
148+
x = x.at[m_offset:, m_offset:].set(
149+
jnp.einsum(
150+
"nlm,lam,lan->amn", flmn[n_start_ind:], Delta, Delta[:, :, L - 1 + n]
151+
)
134152
)
135-
)
153+
154+
# OTF TRANSFORM
155+
else:
156+
Delta_el = jnp.zeros((2 * L - 1, 2 * L - 1), dtype=jnp.float64)
157+
for el in range(L):
158+
Delta_el = recursions.risbo_jax.compute_full(Delta_el, jnp.pi / 2, L, el)
159+
x = x.at[m_offset:, m_offset:].add(
160+
jnp.einsum(
161+
"nm,am,an->amn",
162+
flmn[n_start_ind:, el],
163+
Delta_el,
164+
Delta_el[:, L - 1 + n],
165+
)
166+
)
136167

137168
# APPLY SIGN FUNCTION AND PHASE SHIFT
138169
x = jnp.einsum("amn,m,n,a->nam", x, 1j ** (-m), 1j ** (n), jnp.exp(1j * m * theta0))
@@ -151,9 +182,9 @@ def inverse_transform_jax(
151182

152183
def forward_transform(
153184
f: np.ndarray,
154-
DW: np.ndarray,
155185
L: int,
156186
N: int,
187+
DW: np.ndarray = None,
157188
reality: bool = False,
158189
sampling: str = "mw",
159190
) -> np.ndarray:
@@ -162,10 +193,11 @@ def forward_transform(
162193
163194
Args:
164195
f (np.ndarray): Function sampled on the rotation group.
165-
DW (Tuple[np.ndarray, np.ndarray], optional): Fourier coefficients of the reduced
166-
Wigner d-functions and the corresponding upsampled quadrature weights.
167196
L (int): Harmonic band-limit.
168197
N (int): Azimuthal band-limit.
198+
DW (Tuple[np.ndarray, np.ndarray], optional): Fourier coefficients of the reduced
199+
Wigner d-functions and the corresponding upsampled quadrature weights.
200+
Defaults to None.
169201
reality (bool, optional): Whether the signal on the sphere is real. If so,
170202
conjugate symmetry is exploited to reduce computational costs.
171203
Defaults to False.
@@ -181,9 +213,6 @@ def forward_transform(
181213
f"Fourier-Wigner algorithm does not support {sampling} sampling."
182214
)
183215

184-
# EXTRACT VARIOUS PRECOMPUTES
185-
Delta, Quads = DW
186-
187216
# INDEX VALUES
188217
n_start_ind = N - 1 if reality else 0
189218
m_offset = 1 if sampling.lower() == "mwss" else 0
@@ -223,14 +252,44 @@ def forward_transform(
223252
# NB: Our convention here is conjugate to that of SSHT, in which
224253
# the weights are conjugate but applied flipped and therefore are
225254
# equivalent. To avoid flipping here we simply conjugate the weights.
226-
x = np.einsum("nbm,b->nbm", x, Quads)
227255

228-
# COMPUTE GMM BY FFT
229-
x = np.fft.fft(x, axis=1, norm="forward")
230-
x = np.fft.fftshift(x, axes=1)[:, L - 1 : 3 * L - 2]
256+
# PRECOMPUTE TRANSFORM
257+
if DW is not None:
258+
# EXTRACT VARIOUS PRECOMPUTES
259+
Delta, Quads = DW
231260

232-
# Calculate flmn = i^(n-m)\sum_t Delta^l_tm Delta^l_tn G_mnt
233-
x = np.einsum("nam,lam,lan->nlm", x, Delta, Delta[:, :, L - 1 + n])
261+
# APPLY QUADRATURE
262+
x = np.einsum("nbm,b->nbm", x, Quads)
263+
264+
# COMPUTE GMM BY FFT
265+
x = np.fft.fft(x, axis=1, norm="forward")
266+
x = np.fft.fftshift(x, axes=1)[:, L - 1 : 3 * L - 2]
267+
268+
# CALCULATE flmn = i^(n-m)\sum_t Delta^l_tm Delta^l_tn G_mnt
269+
x = np.einsum("nam,lam,lan->nlm", x, Delta, Delta[:, :, L - 1 + n])
270+
271+
# OTF TRANSFORM
272+
else:
273+
# COMPUTE QUADRATURE WEIGHTS
274+
Quads = np.zeros(4 * L - 3, dtype=np.complex128)
275+
for mm in range(-2 * (L - 1), 2 * (L - 1) + 1):
276+
Quads[mm + 2 * (L - 1)] = quadrature.mw_weights(-mm)
277+
Quads = np.fft.ifft(np.fft.ifftshift(Quads), norm="forward")
278+
279+
# APPLY QUADRATURE
280+
x = np.einsum("nbm,b->nbm", x, Quads)
281+
282+
# COMPUTE GMM BY FFT
283+
x = np.fft.fft(x, axis=1, norm="forward")
284+
x = np.fft.fftshift(x, axes=1)[:, L - 1 : 3 * L - 2]
285+
286+
# CALCULATE flmn = i^(n-m)\sum_t Delta^l_tm Delta^l_tn G_mnt
287+
Delta_el = np.zeros((2 * L - 1, 2 * L - 1), dtype=np.float64)
288+
xx = np.zeros((x.shape[0], L, x.shape[-1]), dtype=x.dtype)
289+
for el in range(L):
290+
Delta_el = recursions.risbo.compute_full(Delta_el, np.pi / 2, L, el)
291+
xx[:, el] = np.einsum("nam,am,an->nm", x, Delta_el, Delta_el[:, L - 1 + n])
292+
x = xx
234293
x = np.einsum("nbm,m,n->nbm", x, 1j ** (m), 1j ** (-n))
235294

236295
# SYMMETRY REFLECT FOR N < 0
@@ -246,12 +305,12 @@ def forward_transform(
246305
return x * (2.0 * np.pi) ** 2
247306

248307

249-
@partial(jit, static_argnums=(2, 3, 4, 5))
308+
@partial(jit, static_argnums=(1, 2, 4, 5))
250309
def forward_transform_jax(
251310
f: jnp.ndarray,
252-
DW: jnp.ndarray,
253311
L: int,
254312
N: int,
313+
DW: jnp.ndarray = None,
255314
reality: bool = False,
256315
sampling: str = "mw",
257316
) -> jnp.ndarray:
@@ -260,10 +319,11 @@ def forward_transform_jax(
260319
261320
Args:
262321
f (jnp.ndarray): Function sampled on the rotation group.
263-
DW (Tuple[jnp.ndarray, jnp.ndarray]): Fourier coefficients of the reduced
264-
Wigner d-functions and the corresponding upsampled quadrature weights.
265322
L (int): Harmonic band-limit.
266323
N (int): Azimuthal band-limit.
324+
DW (Tuple[np.ndarray, np.ndarray], optional): Fourier coefficients of the reduced
325+
Wigner d-functions and the corresponding upsampled quadrature weights.
326+
Defaults to None.
267327
reality (bool, optional): Whether the signal on the sphere is real. If so,
268328
conjugate symmetry is exploited to reduce computational costs.
269329
Defaults to False.
@@ -279,9 +339,6 @@ def forward_transform_jax(
279339
f"Fourier-Wigner algorithm does not support {sampling} sampling."
280340
)
281341

282-
# EXTRACT VARIOUS PRECOMPUTES
283-
Delta, Quads = DW
284-
285342
# INDEX VALUES
286343
n_start_ind = N - 1 if reality else 0
287344
m_offset = 1 if sampling.lower() == "mwss" else 0
@@ -321,14 +378,45 @@ def forward_transform_jax(
321378
# NB: Our convention here is conjugate to that of SSHT, in which
322379
# the weights are conjugate but applied flipped and therefore are
323380
# equivalent. To avoid flipping here we simply conjugate the weights.
324-
x = jnp.einsum("nbm,b->nbm", x, Quads)
325381

326-
# COMPUTE GMM BY FFT
327-
x = jnp.fft.fft(x, axis=1, norm="forward")
328-
x = jnp.fft.fftshift(x, axes=1)[:, L - 1 : 3 * L - 2]
382+
# PRECOMPUTE TRANSFORM
383+
if DW is not None:
384+
# EXTRACT VARIOUS PRECOMPUTES
385+
Delta, Quads = DW
386+
387+
# APPLY QUADRATURE
388+
x = jnp.einsum("nbm,b->nbm", x, Quads)
389+
390+
# COMPUTE GMM BY FFT
391+
x = jnp.fft.fft(x, axis=1, norm="forward")
392+
x = jnp.fft.fftshift(x, axes=1)[:, L - 1 : 3 * L - 2]
393+
394+
# Calculate flmn = i^(n-m)\sum_t Delta^l_tm Delta^l_tn G_mnt
395+
x = jnp.einsum("nam,lam,lan->nlm", x, Delta, Delta[:, :, L - 1 + n])
396+
397+
else:
398+
Quads = jnp.zeros(4 * L - 3, dtype=jnp.complex128)
399+
for mm in range(-2 * (L - 1), 2 * (L - 1) + 1):
400+
Quads = Quads.at[mm + 2 * (L - 1)].set(quadrature_jax.mw_weights(-mm))
401+
Quads = jnp.fft.ifft(jnp.fft.ifftshift(Quads), norm="forward")
402+
403+
# APPLY QUADRATURE
404+
x = jnp.einsum("nbm,b->nbm", x, Quads)
405+
406+
# COMPUTE GMM BY FFT
407+
x = jnp.fft.fft(x, axis=1, norm="forward")
408+
x = jnp.fft.fftshift(x, axes=1)[:, L - 1 : 3 * L - 2]
409+
410+
# CALCULATE flmn = i^(n-m)\sum_t Delta^l_tm Delta^l_tn G_mnt
411+
Delta_el = jnp.zeros((2 * L - 1, 2 * L - 1), dtype=jnp.float64)
412+
xx = jnp.zeros((x.shape[0], L, x.shape[-1]), dtype=x.dtype)
413+
for el in range(L):
414+
Delta_el = recursions.risbo_jax.compute_full(Delta_el, jnp.pi / 2, L, el)
415+
xx = xx.at[:, el].set(
416+
jnp.einsum("nam,am,an->nm", x, Delta_el, Delta_el[:, L - 1 + n])
417+
)
418+
x = xx
329419

330-
# Calculate flmn = i^(n-m)\sum_t Delta^l_tm Delta^l_tn G_mnt
331-
x = jnp.einsum("nam,lam,lan->nlm", x, Delta, Delta[:, :, L - 1 + n])
332420
x = jnp.einsum("nbm,m,n->nbm", x, 1j ** (m), 1j ** (-n))
333421

334422
# SYMMETRY REFLECT FOR N < 0

0 commit comments

Comments
 (0)