Skip to content

Commit cd238cf

Browse files
authored
Merge pull request #166 from astro-informatics/feature/GPU_precompute_kernels
black format and add GPU precompute kernels
2 parents a796dfe + c91068c commit cd238cf

31 files changed

353 additions and 476 deletions (lines changed)

s2fft/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,9 @@
1010

1111
import logging
1212
from jax.config import config
13+
1314
if config.read("jax_enable_x64") is False:
1415
logger = logging.getLogger("s2fft")
15-
logger.warning("JAX is not using 64-bit precision. This will dramatically affect numerical precision at even moderate L.")
16+
logger.warning(
17+
"JAX is not using 64-bit precision. This will dramatically affect numerical precision at even moderate L."
18+
)

s2fft/base_transforms/spherical.py

Lines changed: 9 additions & 52 deletions
Original file line number | Diff line number | Diff line change
@@ -297,18 +297,15 @@ def _compute_inverse_direct(
297297
f = np.zeros(samples.f_shape(L, sampling, nside), dtype=np.complex128)
298298

299299
for t, theta in enumerate(thetas):
300-
301300
if sampling.lower() == "healpix":
302301
phis_ring = samples.phis_ring(t, nside)
303302

304303
for el in range(max(L_lower, abs(spin)), L):
305-
306304
dl = recursions.turok.compute_slice(theta, el, L, -spin, reality)
307305

308306
elfactor = np.sqrt((2 * el + 1) / (4 * np.pi))
309307

310308
for p, phi in enumerate(phis_ring):
311-
312309
if sampling.lower() != "healpix":
313310
entry = (t, p)
314311

@@ -459,22 +456,17 @@ def _compute_inverse_sov_fft(
459456
m_offset = 1 if sampling in ["mwss", "healpix"] else 0
460457

461458
for t, theta in enumerate(thetas):
462-
463459
phi_ring_offset = (
464-
samples.p2phi_ring(t, 0, nside)
465-
if sampling.lower() == "healpix"
466-
else 0
460+
samples.p2phi_ring(t, 0, nside) if sampling.lower() == "healpix" else 0
467461
)
468462

469463
for el in range(max(L_lower, abs(spin)), L):
470-
471464
dl = recursions.turok.compute_slice(theta, el, L, -spin, reality)
472465

473466
elfactor = np.sqrt((2 * el + 1) / (4 * np.pi))
474467

475468
m_start_ind = 0 if reality else -el
476469
for m in range(m_start_ind, el + 1):
477-
478470
phase_shift = (
479471
np.exp(1j * m * phi_ring_offset)
480472
if sampling.lower() == "healpix"
@@ -506,9 +498,7 @@ def _compute_inverse_sov_fft(
506498
norm="forward",
507499
)
508500
else:
509-
f = np.fft.ifft(
510-
np.fft.ifftshift(ftm, axes=1), axis=1, norm="forward"
511-
)
501+
f = np.fft.ifft(np.fft.ifftshift(ftm, axes=1), axis=1, norm="forward")
512502

513503
return f
514504

@@ -554,24 +544,17 @@ def _compute_inverse_sov_fft_vectorized(
554544
m_offset = 1 if sampling in ["mwss", "healpix"] else 0
555545

556546
for t, theta in enumerate(thetas):
557-
558547
phase_shift = (
559548
samples.ring_phase_shift_hp(L, t, nside, False, reality)
560549
if sampling.lower() == "healpix"
561550
else 1.0
562551
)
563552

564553
for el in range(max(L_lower, abs(spin)), L):
565-
566554
dl = recursions.turok.compute_slice(theta, el, L, -spin, reality)
567555
elfactor = np.sqrt((2 * el + 1) / (4 * np.pi))
568556
m_start_ind = L - 1 if reality else 0
569-
val = (
570-
elfactor
571-
* dl[m_start_ind:]
572-
* flm[el, m_start_ind:]
573-
* phase_shift
574-
)
557+
val = elfactor * dl[m_start_ind:] * flm[el, m_start_ind:] * phase_shift
575558
if reality and sampling.lower() == "healpix":
576559
ftm[t, m_offset : L - 1 + m_offset] += np.flip(np.conj(val[1:]))
577560

@@ -589,9 +572,7 @@ def _compute_inverse_sov_fft_vectorized(
589572
norm="forward",
590573
)
591574
else:
592-
f = np.fft.ifft(
593-
np.fft.ifftshift(ftm, axes=1), axis=1, norm="forward"
594-
)
575+
f = np.fft.ifft(np.fft.ifftshift(ftm, axes=1), axis=1, norm="forward")
595576

596577
return f
597578

@@ -641,30 +622,23 @@ def _compute_forward_direct(
641622
phis_ring = samples.phis_equiang(L, sampling)
642623

643624
for t, theta in enumerate(thetas):
644-
645625
if sampling.lower() == "healpix":
646626
phis_ring = samples.phis_ring(t, nside)
647627

648628
for el in range(max(L_lower, abs(spin)), L):
649-
650629
dl = recursions.turok.compute_slice(theta, el, L, -spin, reality)
651630

652631
elfactor = np.sqrt((2 * el + 1) / (4 * np.pi))
653632

654633
for p, phi in enumerate(phis_ring):
655-
656634
if sampling.lower() != "healpix":
657635
entry = (t, p)
658636
else:
659637
entry = samples.hp_ang2pix(nside, theta, phi)
660638

661639
if reality:
662640
flm[el, L - 1] += (
663-
weights[t]
664-
* (-1) ** spin
665-
* elfactor
666-
* dl[L - 1]
667-
* f[entry]
641+
weights[t] * (-1) ** spin * elfactor * dl[L - 1] * f[entry]
668642
) # m = 0
669643
for m in range(1, el + 1):
670644
val = (
@@ -738,7 +712,6 @@ def _compute_forward_sov(
738712

739713
ftm = np.zeros((len(thetas), 2 * L - 1), dtype=np.complex128)
740714
for t, theta in enumerate(thetas):
741-
742715
if sampling.lower() == "healpix":
743716
phis_ring = samples.phis_ring(t, nside)
744717

@@ -755,20 +728,14 @@ def _compute_forward_sov(
755728
flm = np.zeros(samples.flm_shape(L), dtype=np.complex128)
756729

757730
for t, theta in enumerate(thetas):
758-
759731
for el in range(max(L_lower, abs(spin)), L):
760-
761732
dl = recursions.turok.compute_slice(theta, el, L, -spin, reality)
762733

763734
elfactor = np.sqrt((2 * el + 1) / (4 * np.pi))
764735

765736
if reality:
766737
flm[el, L - 1] += (
767-
weights[t]
768-
* (-1) ** spin
769-
* elfactor
770-
* dl[L - 1]
771-
* ftm[t, L - 1]
738+
weights[t] * (-1) ** spin * elfactor * dl[L - 1] * ftm[t, L - 1]
772739
) # m = 0
773740
for m in range(1, el + 1):
774741
val = (
@@ -852,20 +819,14 @@ def _compute_forward_sov_fft(
852819
ftm_temp = ftm_temp[:, :-1]
853820
ftm[:, L - 1 + m_offset :] = ftm_temp
854821
else:
855-
ftm = np.fft.fftshift(
856-
np.fft.fft(f, axis=1, norm="backward"), axes=1
857-
)
822+
ftm = np.fft.fftshift(np.fft.fft(f, axis=1, norm="backward"), axes=1)
858823

859824
for t, theta in enumerate(thetas):
860-
861825
phi_ring_offset = (
862-
samples.p2phi_ring(t, 0, nside)
863-
if sampling.lower() == "healpix"
864-
else 0
826+
samples.p2phi_ring(t, 0, nside) if sampling.lower() == "healpix" else 0
865827
)
866828

867829
for el in range(max(L_lower, abs(spin)), L):
868-
869830
dl = recursions.turok.compute_slice(theta, el, L, -spin, reality)
870831

871832
elfactor = np.sqrt((2 * el + 1) / (4 * np.pi))
@@ -974,20 +935,16 @@ def _compute_forward_sov_fft_vectorized(
974935
t = t[:, :-1]
975936
ftm[:, L - 1 + m_offset :] = t
976937
else:
977-
ftm = np.fft.fftshift(
978-
np.fft.fft(f, axis=1, norm="backward"), axes=1
979-
)
938+
ftm = np.fft.fftshift(np.fft.fft(f, axis=1, norm="backward"), axes=1)
980939

981940
for t, theta in enumerate(thetas):
982-
983941
phase_shift = (
984942
samples.ring_phase_shift_hp(L, t, nside, True, reality)
985943
if sampling.lower() == "healpix"
986944
else 1.0
987945
)
988946

989947
for el in range(max(L_lower, abs(spin)), L):
990-
991948
dl = recursions.turok.compute_slice(theta, el, L, -spin, reality)
992949

993950
elfactor = np.sqrt((2 * el + 1) / (4 * np.pi))

s2fft/base_transforms/wigner.py

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -76,9 +76,7 @@ def inverse(
7676
if reality:
7777
f = np.fft.irfft(fban[N - 1 :], 2 * N - 1, axis=ax, norm="forward")
7878
else:
79-
f = np.fft.ifft(
80-
np.fft.ifftshift(fban, axes=ax), axis=ax, norm="forward"
81-
)
79+
f = np.fft.ifft(np.fft.ifftshift(fban, axes=ax), axis=ax, norm="forward")
8280

8381
return f
8482

0 commit comments

Comments
 (0)