77
88def inverse_transform (
99 flmn : np .ndarray ,
10- delta : np .ndarray ,
10+ DW : np .ndarray ,
1111 L : int ,
1212 N : int ,
1313 reality : bool = False ,
@@ -18,7 +18,7 @@ def inverse_transform(
1818
1919 Args:
2020 flmn (np.ndarray): Wigner coefficients.
21- delta (Tuple[np.ndarray, np.ndarray]): Fourier coefficients of the reduced
21+ DW (Tuple[np.ndarray, np.ndarray], optional ): Fourier coefficients of the reduced
2222 Wigner d-functions and the corresponding upsampled quadrature weights.
2323 L (int): Harmonic band-limit.
2424 N (int): Azimuthal band-limit.
@@ -32,6 +32,14 @@ def inverse_transform(
3232 np.ndarray: Pixel-space function sampled on the rotation group.
3333
3434 """
35+ if sampling .lower () not in ["mw" , "mwss" ]:
36+ raise ValueError (
37+ f"Fourier-Wigner algorithm does not support { sampling } sampling."
38+ )
39+
40+ # EXTRACT VARIOUS PRECOMPUTES
41+ Delta , _ = DW
42+
3543 # INDEX VALUES
3644 n_start_ind = N - 1 if reality else 0
3745 n_dim = N if reality else 2 * N - 1
@@ -44,13 +52,13 @@ def inverse_transform(
4452 m = np .arange (- L + 1 - m_offset , L )
4553 n = np .arange (n_start_ind - N + 1 , N )
4654
47- # Calculate fmna = i^(n-m)\sum_L delta ^l_am delta ^l_an f^l_mn(2l+1)/(8pi^2)
55+ # Calculate fmna = i^(n-m)\sum_L Delta ^l_am Delta ^l_an f^l_mn(2l+1)/(8pi^2)
4856 x = np .zeros ((xnlm_size , xnlm_size , n_dim ), dtype = flmn .dtype )
4957 x [m_offset :, m_offset :] = np .einsum (
5058 "nlm,lam,lan,l->amn" ,
5159 flmn [n_start_ind :],
52- delta [ 0 ] ,
53- delta [ 0 ] [:, :, L - 1 + n ],
60+ Delta ,
61+ Delta [:, :, L - 1 + n ],
5462 (2 * np .arange (L ) + 1 ) / (8 * np .pi ** 2 ),
5563 )
5664
@@ -72,7 +80,7 @@ def inverse_transform(
7280@partial (jit , static_argnums = (2 , 3 , 4 , 5 ))
7381def inverse_transform_jax (
7482 flmn : jnp .ndarray ,
75- delta : jnp .ndarray ,
83+ DW : jnp .ndarray ,
7684 L : int ,
7785 N : int ,
7886 reality : bool = False ,
@@ -83,7 +91,7 @@ def inverse_transform_jax(
8391
8492 Args:
8593 flmn (jnp.ndarray): Wigner coefficients.
86- delta (Tuple[jnp .ndarray, jnp .ndarray]): Fourier coefficients of the reduced
94+ DW (Tuple[np .ndarray, np .ndarray]): Fourier coefficients of the reduced
8795 Wigner d-functions and the corresponding upsampled quadrature weights.
8896 L (int): Harmonic band-limit.
8997 N (int): Azimuthal band-limit.
@@ -97,6 +105,14 @@ def inverse_transform_jax(
97105 jnp.ndarray: Pixel-space function sampled on the rotation group.
98106
99107 """
108+ if sampling .lower () not in ["mw" , "mwss" ]:
109+ raise ValueError (
110+ f"Fourier-Wigner algorithm does not support { sampling } sampling."
111+ )
112+
113+ # EXTRACT VARIOUS PRECOMPUTES
114+ Delta , _ = DW
115+
100116 # INDEX VALUES
101117 n_start_ind = N - 1 if reality else 0
102118 n_dim = N if reality else 2 * N - 1
@@ -109,17 +125,15 @@ def inverse_transform_jax(
109125 m = jnp .arange (- L + 1 - m_offset , L )
110126 n = jnp .arange (n_start_ind - N + 1 , N )
111127
112- # Calculate fmna = i^(n-m)\sum_L delta^l_am delta^l_an f^l_mn(2l+1)/(8pi^2)
113- x = jnp .zeros ((xnlm_size , xnlm_size , n_dim ), dtype = flmn .dtype )
128+ # Calculate fmna = i^(n-m)\sum_L Delta^l_am Delta^l_an f^l_mn(2l+1)/(8pi^2)
129+ x = jnp .zeros ((xnlm_size , xnlm_size , n_dim ), dtype = jnp .complex128 )
130+ flmn = jnp .einsum ("nlm,l->nlm" , flmn , (2 * jnp .arange (L ) + 1 ) / (8 * jnp .pi ** 2 ))
114131 x = x .at [m_offset :, m_offset :].set (
115132 jnp .einsum (
116- "nlm,lam,lan,l->amn" ,
117- flmn [n_start_ind :],
118- delta [0 ],
119- delta [0 ][:, :, L - 1 + n ],
120- (2 * jnp .arange (L ) + 1 ) / (8 * jnp .pi ** 2 ),
133+ "nlm,lam,lan->amn" , flmn [n_start_ind :], Delta , Delta [:, :, L - 1 + n ]
121134 )
122135 )
136+
123137 # APPLY SIGN FUNCTION AND PHASE SHIFT
124138 x = jnp .einsum ("amn,m,n,a->nam" , x , 1j ** (- m ), 1j ** (n ), jnp .exp (1j * m * theta0 ))
125139
@@ -136,14 +150,19 @@ def inverse_transform_jax(
136150
137151
138152def forward_transform (
139- f : np .ndarray , delta : np .ndarray , L : int , N : int , reality : bool , sampling : str
153+ f : np .ndarray ,
154+ DW : np .ndarray ,
155+ L : int ,
156+ N : int ,
157+ reality : bool = False ,
158+ sampling : str = "mw" ,
140159) -> np .ndarray :
141160 """
142161 Computes the forward Wigner transform using the Fourier decomposition algorithm.
143162
144163 Args:
145164 f (np.ndarray): Function sampled on the rotation group.
146- delta (Tuple[np.ndarray, np.ndarray]): Fourier coefficients of the reduced
165+ DW (Tuple[np.ndarray, np.ndarray], optional ): Fourier coefficients of the reduced
147166 Wigner d-functions and the corresponding upsampled quadrature weights.
148167 L (int): Harmonic band-limit.
149168 N (int): Azimuthal band-limit.
@@ -157,6 +176,14 @@ def forward_transform(
157176 np.ndarray: Wigner coefficients of function f.
158177
159178 """
179+ if sampling .lower () not in ["mw" , "mwss" ]:
180+ raise ValueError (
181+ f"Fourier-Wigner algorithm does not support { sampling } sampling."
182+ )
183+
184+ # EXTRACT VARIOUS PRECOMPUTES
185+ Delta , Quads = DW
186+
160187 # INDEX VALUES
161188 n_start_ind = N - 1 if reality else 0
162189 m_offset = 1 if sampling .lower () == "mwss" else 0
@@ -193,14 +220,17 @@ def forward_transform(
193220 x = np .fft .ifft (x , axis = 1 , norm = "forward" )
194221
195222 # PERFORM QUADRATURE CONVOLUTION AS FFT REWEIGHTING IN REAL SPACE
196- x = np .einsum ("nbm,b->nbm" , x , delta [1 ])
223+ # NB: Our convention here is conjugate to that of SSHT, in which
224+ # the weights are conjugate but applied flipped and therefore are
225+ # equivalent. To avoid flipping here he simply conjugate the weights.
226+ x = np .einsum ("nbm,b->nbm" , x , Quads )
197227
198228 # COMPUTE GMM BY FFT
199229 x = np .fft .fft (x , axis = 1 , norm = "forward" )
200230 x = np .fft .fftshift (x , axes = 1 )[:, L - 1 : 3 * L - 2 ]
201231
202- # Calculate flmn = i^(n-m)\sum_t delta ^l_tm delta ^l_tn G_mnt
203- x = np .einsum ("nam,lam,lan->nlm" , x , delta [ 0 ], delta [ 0 ] [:, :, L - 1 + n ])
232+ # Calculate flmn = i^(n-m)\sum_t Delta ^l_tm Delta ^l_tn G_mnt
233+ x = np .einsum ("nam,lam,lan->nlm" , x , Delta , Delta [:, :, L - 1 + n ])
204234 x = np .einsum ("nbm,m,n->nbm" , x , 1j ** (m ), 1j ** (- n ))
205235
206236 # SYMMETRY REFLECT FOR N < 0
@@ -218,14 +248,19 @@ def forward_transform(
218248
219249@partial (jit , static_argnums = (2 , 3 , 4 , 5 ))
220250def forward_transform_jax (
221- f : jnp .ndarray , delta : jnp .ndarray , L : int , N : int , reality : bool , sampling : str
251+ f : jnp .ndarray ,
252+ DW : jnp .ndarray ,
253+ L : int ,
254+ N : int ,
255+ reality : bool = False ,
256+ sampling : str = "mw" ,
222257) -> jnp .ndarray :
223258 """
224259 Computes the forward Wigner transform using the Fourier decomposition algorithm (JAX).
225260
226261 Args:
227262 f (jnp.ndarray): Function sampled on the rotation group.
228- delta (Tuple[jnp.ndarray, jnp.ndarray]): Fourier coefficients of the reduced
263+ DW (Tuple[jnp.ndarray, jnp.ndarray]): Fourier coefficients of the reduced
229264 Wigner d-functions and the corresponding upsampled quadrature weights.
230265 L (int): Harmonic band-limit.
231266 N (int): Azimuthal band-limit.
@@ -239,6 +274,14 @@ def forward_transform_jax(
239274 jnp.ndarray: Wigner coefficients of function f.
240275
241276 """
277+ if sampling .lower () not in ["mw" , "mwss" ]:
278+ raise ValueError (
279+ f"Fourier-Wigner algorithm does not support { sampling } sampling."
280+ )
281+
282+ # EXTRACT VARIOUS PRECOMPUTES
283+ Delta , Quads = DW
284+
242285 # INDEX VALUES
243286 n_start_ind = N - 1 if reality else 0
244287 m_offset = 1 if sampling .lower () == "mwss" else 0
@@ -275,14 +318,17 @@ def forward_transform_jax(
275318 x = jnp .fft .ifft (x , axis = 1 , norm = "forward" )
276319
277320 # PERFORM QUADRATURE CONVOLUTION AS FFT REWEIGHTING IN REAL SPACE
278- x = jnp .einsum ("nbm,b->nbm" , x , delta [1 ])
321+ # NB: Our convention here is conjugate to that of SSHT, in which
322+ # the weights are conjugate but applied flipped and therefore are
323+ # equivalent. To avoid flipping here he simply conjugate the weights.
324+ x = jnp .einsum ("nbm,b->nbm" , x , Quads )
279325
280326 # COMPUTE GMM BY FFT
281327 x = jnp .fft .fft (x , axis = 1 , norm = "forward" )
282328 x = jnp .fft .fftshift (x , axes = 1 )[:, L - 1 : 3 * L - 2 ]
283329
284- # Calculate flmn = i^(n-m)\sum_t delta ^l_tm delta ^l_tn G_mnt
285- x = jnp .einsum ("nam,lam,lan->nlm" , x , delta [ 0 ], delta [ 0 ] [:, :, L - 1 + n ])
330+ # Calculate flmn = i^(n-m)\sum_t Delta ^l_tm Delta ^l_tn G_mnt
331+ x = jnp .einsum ("nam,lam,lan->nlm" , x , Delta , Delta [:, :, L - 1 + n ])
286332 x = jnp .einsum ("nbm,m,n->nbm" , x , 1j ** (m ), 1j ** (- n ))
287333
288334 # SYMMETRY REFLECT FOR N < 0
0 commit comments