44import numpy as np
55from jax import jit
66
7+ from s2fft import recursions
8+ from s2fft .utils import quadrature , quadrature_jax
9+
710
811def inverse_transform (
912 flmn : np .ndarray ,
10- DW : np .ndarray ,
1113 L : int ,
1214 N : int ,
15+ DW : np .ndarray = None ,
1316 reality : bool = False ,
1417 sampling : str = "mw" ,
1518) -> np .ndarray :
@@ -18,10 +21,11 @@ def inverse_transform(
1821
1922 Args:
2023 flmn (np.ndarray): Wigner coefficients.
21- DW (Tuple[np.ndarray, np.ndarray]): Fourier coefficients of the reduced
22- Wigner d-functions and the corresponding upsampled quadrature weights.
2324 L (int): Harmonic band-limit.
2425 N (int): Azimuthal band-limit.
26+ DW (Tuple[np.ndarray, np.ndarray], optional): Fourier coefficients of the reduced
27+ Wigner d-functions and the corresponding upsampled quadrature weights.
28+ Defaults to None.
2529 reality (bool, optional): Whether the signal on the sphere is real. If so,
2630 conjugate symmetry is exploited to reduce computational costs.
2731 Defaults to False.
@@ -37,9 +41,6 @@ def inverse_transform(
3741 f"Fourier-Wigner algorithm does not support { sampling } sampling."
3842 )
3943
40- # EXTRACT VARIOUS PRECOMPUTES
41- Delta , _ = DW
42-
4344 # INDEX VALUES
4445 n_start_ind = N - 1 if reality else 0
4546 n_dim = N if reality else 2 * N - 1
@@ -54,13 +55,27 @@ def inverse_transform(
5455
5556 # Calculate fmna = i^(n-m)\sum_L Delta^l_am Delta^l_an f^l_mn(2l+1)/(8pi^2)
5657 x = np .zeros ((xnlm_size , xnlm_size , n_dim ), dtype = flmn .dtype )
57- x [m_offset :, m_offset :] = np .einsum (
58- "nlm,lam,lan,l->amn" ,
59- flmn [n_start_ind :],
60- Delta ,
61- Delta [:, :, L - 1 + n ],
62- (2 * np .arange (L ) + 1 ) / (8 * np .pi ** 2 ),
63- )
58+ flmn = np .einsum ("nlm,l->nlm" , flmn , (2 * np .arange (L ) + 1 ) / (8 * np .pi ** 2 ))
59+
60+ # PRECOMPUTE TRANSFORM
61+ if DW is not None :
62+ Delta , _ = DW
63+ x = np .zeros ((xnlm_size , xnlm_size , n_dim ), dtype = flmn .dtype )
64+ x [m_offset :, m_offset :] = np .einsum (
65+ "nlm,lam,lan->amn" , flmn [n_start_ind :], Delta , Delta [:, :, L - 1 + n ]
66+ )
67+
68+ # OTF TRANSFORM
69+ else :
70+ Delta_el = np .zeros ((2 * L - 1 , 2 * L - 1 ), dtype = np .float64 )
71+ for el in range (L ):
72+ Delta_el = recursions .risbo .compute_full (Delta_el , np .pi / 2 , L , el )
73+ x [m_offset :, m_offset :] += np .einsum (
74+ "nm,am,an->amn" ,
75+ flmn [n_start_ind :, el ],
76+ Delta_el ,
77+ Delta_el [:, L - 1 + n ],
78+ )
6479
6580 # APPLY SIGN FUNCTION AND PHASE SHIFT
6681 x = np .einsum ("amn,m,n,a->nam" , x , 1j ** (- m ), 1j ** (n ), np .exp (1j * m * theta0 ))
@@ -77,12 +92,12 @@ def inverse_transform(
7792 return np .fft .ifft2 (x , axes = (0 , 2 ), norm = "forward" )
7893
7994
80- @partial (jit , static_argnums = (2 , 3 , 4 , 5 ))
95+ @partial (jit , static_argnums = (1 , 2 , 4 , 5 ))
8196def inverse_transform_jax (
8297 flmn : jnp .ndarray ,
83- DW : jnp .ndarray ,
8498 L : int ,
8599 N : int ,
100+ DW : jnp .ndarray = None ,
86101 reality : bool = False ,
87102 sampling : str = "mw" ,
88103) -> jnp .ndarray :
@@ -91,10 +106,11 @@ def inverse_transform_jax(
91106
92107 Args:
93108 flmn (jnp.ndarray): Wigner coefficients.
94- DW (Tuple[np.ndarray, np.ndarray]): Fourier coefficients of the reduced
95- Wigner d-functions and the corresponding upsampled quadrature weights.
96109 L (int): Harmonic band-limit.
97110 N (int): Azimuthal band-limit.
111+ DW (Tuple[np.ndarray, np.ndarray], optional): Fourier coefficients of the reduced
112+ Wigner d-functions and the corresponding upsampled quadrature weights.
113+ Defaults to None.
98114 reality (bool, optional): Whether the signal on the sphere is real. If so,
99115 conjugate symmetry is exploited to reduce computational costs.
100116 Defaults to False.
@@ -110,9 +126,6 @@ def inverse_transform_jax(
110126 f"Fourier-Wigner algorithm does not support { sampling } sampling."
111127 )
112128
113- # EXTRACT VARIOUS PRECOMPUTES
114- Delta , _ = DW
115-
116129 # INDEX VALUES
117130 n_start_ind = N - 1 if reality else 0
118131 n_dim = N if reality else 2 * N - 1
@@ -128,11 +141,29 @@ def inverse_transform_jax(
128141 # Calculate fmna = i^(n-m)\sum_L Delta^l_am Delta^l_an f^l_mn(2l+1)/(8pi^2)
129142 x = jnp .zeros ((xnlm_size , xnlm_size , n_dim ), dtype = jnp .complex128 )
130143 flmn = jnp .einsum ("nlm,l->nlm" , flmn , (2 * jnp .arange (L ) + 1 ) / (8 * jnp .pi ** 2 ))
131- x = x .at [m_offset :, m_offset :].set (
132- jnp .einsum (
133- "nlm,lam,lan->amn" , flmn [n_start_ind :], Delta , Delta [:, :, L - 1 + n ]
144+
145+ # PRECOMPUTE TRANSFORM
146+ if DW is not None :
147+ Delta , _ = DW
148+ x = x .at [m_offset :, m_offset :].set (
149+ jnp .einsum (
150+ "nlm,lam,lan->amn" , flmn [n_start_ind :], Delta , Delta [:, :, L - 1 + n ]
151+ )
134152 )
135- )
153+
154+ # OTF TRANSFORM
155+ else :
156+ Delta_el = jnp .zeros ((2 * L - 1 , 2 * L - 1 ), dtype = jnp .float64 )
157+ for el in range (L ):
158+ Delta_el = recursions .risbo_jax .compute_full (Delta_el , jnp .pi / 2 , L , el )
159+ x = x .at [m_offset :, m_offset :].add (
160+ jnp .einsum (
161+ "nm,am,an->amn" ,
162+ flmn [n_start_ind :, el ],
163+ Delta_el ,
164+ Delta_el [:, L - 1 + n ],
165+ )
166+ )
136167
137168 # APPLY SIGN FUNCTION AND PHASE SHIFT
138169 x = jnp .einsum ("amn,m,n,a->nam" , x , 1j ** (- m ), 1j ** (n ), jnp .exp (1j * m * theta0 ))
@@ -151,9 +182,9 @@ def inverse_transform_jax(
151182
152183def forward_transform (
153184 f : np .ndarray ,
154- DW : np .ndarray ,
155185 L : int ,
156186 N : int ,
187+ DW : np .ndarray = None ,
157188 reality : bool = False ,
158189 sampling : str = "mw" ,
159190) -> np .ndarray :
@@ -162,10 +193,11 @@ def forward_transform(
162193
163194 Args:
164195 f (np.ndarray): Function sampled on the rotation group.
165- DW (Tuple[np.ndarray, np.ndarray], optional): Fourier coefficients of the reduced
166- Wigner d-functions and the corresponding upsampled quadrature weights.
167196 L (int): Harmonic band-limit.
168197 N (int): Azimuthal band-limit.
198+ DW (Tuple[np.ndarray, np.ndarray], optional): Fourier coefficients of the reduced
199+ Wigner d-functions and the corresponding upsampled quadrature weights.
200+ Defaults to None.
169201 reality (bool, optional): Whether the signal on the sphere is real. If so,
170202 conjugate symmetry is exploited to reduce computational costs.
171203 Defaults to False.
@@ -181,9 +213,6 @@ def forward_transform(
181213 f"Fourier-Wigner algorithm does not support { sampling } sampling."
182214 )
183215
184- # EXTRACT VARIOUS PRECOMPUTES
185- Delta , Quads = DW
186-
187216 # INDEX VALUES
188217 n_start_ind = N - 1 if reality else 0
189218 m_offset = 1 if sampling .lower () == "mwss" else 0
@@ -223,14 +252,44 @@ def forward_transform(
223252 # NB: Our convention here is conjugate to that of SSHT, in which
224253 # the weights are conjugate but applied flipped and therefore are
225254 # equivalent. To avoid flipping here we simply conjugate the weights.
226- x = np .einsum ("nbm,b->nbm" , x , Quads )
227255
228- # COMPUTE GMM BY FFT
229- x = np .fft .fft (x , axis = 1 , norm = "forward" )
230- x = np .fft .fftshift (x , axes = 1 )[:, L - 1 : 3 * L - 2 ]
256+ # PRECOMPUTE TRANSFORM
257+ if DW is not None :
258+ # EXTRACT VARIOUS PRECOMPUTES
259+ Delta , Quads = DW
231260
232- # Calculate flmn = i^(n-m)\sum_t Delta^l_tm Delta^l_tn G_mnt
233- x = np .einsum ("nam,lam,lan->nlm" , x , Delta , Delta [:, :, L - 1 + n ])
261+ # APPLY QUADRATURE
262+ x = np .einsum ("nbm,b->nbm" , x , Quads )
263+
264+ # COMPUTE GMM BY FFT
265+ x = np .fft .fft (x , axis = 1 , norm = "forward" )
266+ x = np .fft .fftshift (x , axes = 1 )[:, L - 1 : 3 * L - 2 ]
267+
268+ # CALCULATE flmn = i^(n-m)\sum_t Delta^l_tm Delta^l_tn G_mnt
269+ x = np .einsum ("nam,lam,lan->nlm" , x , Delta , Delta [:, :, L - 1 + n ])
270+
271+ # OTF TRANSFORM
272+ else :
273+ # COMPUTE QUADRATURE WEIGHTS
274+ Quads = np .zeros (4 * L - 3 , dtype = np .complex128 )
275+ for mm in range (- 2 * (L - 1 ), 2 * (L - 1 ) + 1 ):
276+ Quads [mm + 2 * (L - 1 )] = quadrature .mw_weights (- mm )
277+ Quads = np .fft .ifft (np .fft .ifftshift (Quads ), norm = "forward" )
278+
279+ # APPLY QUADRATURE
280+ x = np .einsum ("nbm,b->nbm" , x , Quads )
281+
282+ # COMPUTE GMM BY FFT
283+ x = np .fft .fft (x , axis = 1 , norm = "forward" )
284+ x = np .fft .fftshift (x , axes = 1 )[:, L - 1 : 3 * L - 2 ]
285+
286+ # CALCULATE flmn = i^(n-m)\sum_t Delta^l_tm Delta^l_tn G_mnt
287+ Delta_el = np .zeros ((2 * L - 1 , 2 * L - 1 ), dtype = np .float64 )
288+ xx = np .zeros ((x .shape [0 ], L , x .shape [- 1 ]), dtype = x .dtype )
289+ for el in range (L ):
290+ Delta_el = recursions .risbo .compute_full (Delta_el , np .pi / 2 , L , el )
291+ xx [:, el ] = np .einsum ("nam,am,an->nm" , x , Delta_el , Delta_el [:, L - 1 + n ])
292+ x = xx
234293 x = np .einsum ("nbm,m,n->nbm" , x , 1j ** (m ), 1j ** (- n ))
235294
236295 # SYMMETRY REFLECT FOR N < 0
@@ -246,12 +305,12 @@ def forward_transform(
246305 return x * (2.0 * np .pi ) ** 2
247306
248307
249- @partial (jit , static_argnums = (2 , 3 , 4 , 5 ))
308+ @partial (jit , static_argnums = (1 , 2 , 4 , 5 ))
250309def forward_transform_jax (
251310 f : jnp .ndarray ,
252- DW : jnp .ndarray ,
253311 L : int ,
254312 N : int ,
313+ DW : jnp .ndarray = None ,
255314 reality : bool = False ,
256315 sampling : str = "mw" ,
257316) -> jnp .ndarray :
@@ -260,10 +319,11 @@ def forward_transform_jax(
260319
261320 Args:
262321 f (jnp.ndarray): Function sampled on the rotation group.
263- DW (Tuple[jnp.ndarray, jnp.ndarray]): Fourier coefficients of the reduced
264- Wigner d-functions and the corresponding upsampled quadrature weights.
265322 L (int): Harmonic band-limit.
266323 N (int): Azimuthal band-limit.
324+ DW (Tuple[np.ndarray, np.ndarray], optional): Fourier coefficients of the reduced
325+ Wigner d-functions and the corresponding upsampled quadrature weights.
326+ Defaults to None.
267327 reality (bool, optional): Whether the signal on the sphere is real. If so,
268328 conjugate symmetry is exploited to reduce computational costs.
269329 Defaults to False.
@@ -279,9 +339,6 @@ def forward_transform_jax(
279339 f"Fourier-Wigner algorithm does not support { sampling } sampling."
280340 )
281341
282- # EXTRACT VARIOUS PRECOMPUTES
283- Delta , Quads = DW
284-
285342 # INDEX VALUES
286343 n_start_ind = N - 1 if reality else 0
287344 m_offset = 1 if sampling .lower () == "mwss" else 0
@@ -321,14 +378,45 @@ def forward_transform_jax(
321378 # NB: Our convention here is conjugate to that of SSHT, in which
322379 # the weights are conjugate but applied flipped and therefore are
323380 # equivalent. To avoid flipping here we simply conjugate the weights.
324- x = jnp .einsum ("nbm,b->nbm" , x , Quads )
325381
326- # COMPUTE GMM BY FFT
327- x = jnp .fft .fft (x , axis = 1 , norm = "forward" )
328- x = jnp .fft .fftshift (x , axes = 1 )[:, L - 1 : 3 * L - 2 ]
382+ # PRECOMPUTE TRANSFORM
383+ if DW is not None :
384+ # EXTRACT VARIOUS PRECOMPUTES
385+ Delta , Quads = DW
386+
387+ # APPLY QUADRATURE
388+ x = jnp .einsum ("nbm,b->nbm" , x , Quads )
389+
390+ # COMPUTE GMM BY FFT
391+ x = jnp .fft .fft (x , axis = 1 , norm = "forward" )
392+ x = jnp .fft .fftshift (x , axes = 1 )[:, L - 1 : 3 * L - 2 ]
393+
394+ # Calculate flmn = i^(n-m)\sum_t Delta^l_tm Delta^l_tn G_mnt
395+ x = jnp .einsum ("nam,lam,lan->nlm" , x , Delta , Delta [:, :, L - 1 + n ])
396+
397+ else :
398+ Quads = jnp .zeros (4 * L - 3 , dtype = jnp .complex128 )
399+ for mm in range (- 2 * (L - 1 ), 2 * (L - 1 ) + 1 ):
400+ Quads = Quads .at [mm + 2 * (L - 1 )].set (quadrature_jax .mw_weights (- mm ))
401+ Quads = jnp .fft .ifft (jnp .fft .ifftshift (Quads ), norm = "forward" )
402+
403+ # APPLY QUADRATURE
404+ x = jnp .einsum ("nbm,b->nbm" , x , Quads )
405+
406+ # COMPUTE GMM BY FFT
407+ x = jnp .fft .fft (x , axis = 1 , norm = "forward" )
408+ x = jnp .fft .fftshift (x , axes = 1 )[:, L - 1 : 3 * L - 2 ]
409+
410+ # CALCULATE flmn = i^(n-m)\sum_t Delta^l_tm Delta^l_tn G_mnt
411+ Delta_el = jnp .zeros ((2 * L - 1 , 2 * L - 1 ), dtype = jnp .float64 )
412+ xx = jnp .zeros ((x .shape [0 ], L , x .shape [- 1 ]), dtype = x .dtype )
413+ for el in range (L ):
414+ Delta_el = recursions .risbo_jax .compute_full (Delta_el , jnp .pi / 2 , L , el )
415+ xx = xx .at [:, el ].set (
416+ jnp .einsum ("nam,am,an->nm" , x , Delta_el , Delta_el [:, L - 1 + n ])
417+ )
418+ x = xx
329419
330- # Calculate flmn = i^(n-m)\sum_t Delta^l_tm Delta^l_tn G_mnt
331- x = jnp .einsum ("nam,lam,lan->nlm" , x , Delta , Delta [:, :, L - 1 + n ])
332420 x = jnp .einsum ("nbm,m,n->nbm" , x , 1j ** (m ), 1j ** (- n ))
333421
334422 # SYMMETRY REFLECT FOR N < 0
0 commit comments