Skip to content

Commit 868ec4e

Browse files
committed
Refactor to remove repeated code
1 parent 9b42bf4 commit 868ec4e

File tree

1 file changed

+30
-39
lines changed

1 file changed

+30
-39
lines changed

s2fft/precompute_transforms/fourier_wigner.py

Lines changed: 30 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -255,43 +255,38 @@ def forward_transform(
255255
# the weights are conjugate but applied flipped and therefore are
256256
# equivalent. To avoid flipping here we simply conjugate the weights.
257257

258-
# PRECOMPUTE TRANSFORM
259258
if precomps is not None:
260-
# EXTRACT VARIOUS PRECOMPUTES
259+
# PRECOMPUTE TRANSFORM
261260
delta, quads = precomps
262-
263-
# APPLY QUADRATURE
264-
x = np.einsum("nbm,b->nbm", x, quads)
265-
266-
# COMPUTE GMM BY FFT
267-
x = np.fft.fft(x, axis=1, norm="forward")
268-
x = np.fft.fftshift(x, axes=1)[:, L - 1 : 3 * L - 2]
269-
270-
# CALCULATE flmn = i^(n-m)\sum_t delta^l_tm delta^l_tn G_mnt
271-
x = np.einsum("nam,lam,lan->nlm", x, delta, delta[:, :, L - 1 + n])
272-
273-
# OTF TRANSFORM
274261
else:
262+
# OTF TRANSFORM
263+
delta = None
275264
# COMPUTE QUADRATURE WEIGHTS
276265
quads = np.zeros(4 * L - 3, dtype=np.complex128)
277266
for mm in range(-2 * (L - 1), 2 * (L - 1) + 1):
278267
quads[mm + 2 * (L - 1)] = quadrature.mw_weights(-mm)
279268
quads = np.fft.ifft(np.fft.ifftshift(quads), norm="forward")
280269

281-
# APPLY QUADRATURE
282-
x = np.einsum("nbm,b->nbm", x, quads)
270+
# APPLY QUADRATURE
271+
x = np.einsum("nbm,b->nbm", x, quads)
283272

284-
# COMPUTE GMM BY FFT
285-
x = np.fft.fft(x, axis=1, norm="forward")
286-
x = np.fft.fftshift(x, axes=1)[:, L - 1 : 3 * L - 2]
273+
# COMPUTE GMM BY FFT
274+
x = np.fft.fft(x, axis=1, norm="forward")
275+
x = np.fft.fftshift(x, axes=1)[:, L - 1 : 3 * L - 2]
287276

288-
# CALCULATE flmn = i^(n-m)\sum_t delta^l_tm delta^l_tn G_mnt
277+
# CALCULATE flmn = i^(n-m)\sum_t delta^l_tm delta^l_tn G_mnt
278+
if delta is not None:
279+
# PRECOMPUTE TRANSFORM
280+
x = np.einsum("nam,lam,lan->nlm", x, delta, delta[:, :, L - 1 + n])
281+
else:
282+
# OTF TRANSFORM
289283
delta_el = np.zeros((2 * L - 1, 2 * L - 1), dtype=np.float64)
290284
xx = np.zeros((x.shape[0], L, x.shape[-1]), dtype=x.dtype)
291285
for el in range(L):
292286
delta_el = recursions.risbo.compute_full(delta_el, np.pi / 2, L, el)
293287
xx[:, el] = np.einsum("nam,am,an->nm", x, delta_el, delta_el[:, L - 1 + n])
294288
x = xx
289+
295290
x = np.einsum("nbm,m,n->nbm", x, 1j ** (m), 1j ** (-n))
296291

297292
# SYMMETRY REFLECT FOR N < 0
@@ -381,35 +376,31 @@ def forward_transform_jax(
381376
# the weights are conjugate but applied flipped and therefore are
382377
# equivalent. To avoid flipping here we simply conjugate the weights.
383378

384-
# PRECOMPUTE TRANSFORM
385379
if precomps is not None:
386-
# EXTRACT VARIOUS PRECOMPUTES
380+
# PRECOMPUTE TRANSFORM
387381
delta, quads = precomps
388-
389-
# APPLY QUADRATURE
390-
x = jnp.einsum("nbm,b->nbm", x, quads)
391-
392-
# COMPUTE GMM BY FFT
393-
x = jnp.fft.fft(x, axis=1, norm="forward")
394-
x = jnp.fft.fftshift(x, axes=1)[:, L - 1 : 3 * L - 2]
395-
396-
# Calculate flmn = i^(n-m)\sum_t delta^l_tm delta^l_tn G_mnt
397-
x = jnp.einsum("nam,lam,lan->nlm", x, delta, delta[:, :, L - 1 + n])
398-
399382
else:
383+
# OTF TRANSFORM
384+
delta = None
385+
# COMPUTE QUADRATURE WEIGHTS
400386
quads = jnp.zeros(4 * L - 3, dtype=jnp.complex128)
401387
for mm in range(-2 * (L - 1), 2 * (L - 1) + 1):
402388
quads = quads.at[mm + 2 * (L - 1)].set(quadrature_jax.mw_weights(-mm))
403389
quads = jnp.fft.ifft(jnp.fft.ifftshift(quads), norm="forward")
404390

405-
# APPLY QUADRATURE
406-
x = jnp.einsum("nbm,b->nbm", x, quads)
391+
# APPLY QUADRATURE
392+
x = jnp.einsum("nbm,b->nbm", x, quads)
407393

408-
# COMPUTE GMM BY FFT
409-
x = jnp.fft.fft(x, axis=1, norm="forward")
410-
x = jnp.fft.fftshift(x, axes=1)[:, L - 1 : 3 * L - 2]
394+
# COMPUTE GMM BY FFT
395+
x = jnp.fft.fft(x, axis=1, norm="forward")
396+
x = jnp.fft.fftshift(x, axes=1)[:, L - 1 : 3 * L - 2]
411397

412-
# CALCULATE flmn = i^(n-m)\sum_t delta^l_tm delta^l_tn G_mnt
398+
# Calculate flmn = i^(n-m)\sum_t delta^l_tm delta^l_tn G_mnt
399+
if delta is not None:
400+
# PRECOMPUTE TRANSFORM
401+
x = jnp.einsum("nam,lam,lan->nlm", x, delta, delta[:, :, L - 1 + n])
402+
else:
403+
# OTF TRANSFORM
413404
delta_el = jnp.zeros((2 * L - 1, 2 * L - 1), dtype=jnp.float64)
414405
xx = jnp.zeros((x.shape[0], L, x.shape[-1]), dtype=x.dtype)
415406
for el in range(L):

0 commit comments

Comments
 (0)