|
3 | 3 | import numpy as np |
4 | 4 | import pyssht as ssht |
5 | 5 | import pytest |
| 6 | +import torch |
6 | 7 |
|
7 | 8 | from s2fft.recursions.price_mcewen import generate_precomputes |
8 | 9 | from s2fft.sampling import s2_samples as samples |
|
15 | 16 | spin_to_test = [-2, 0, 1] |
16 | 17 | nside_to_test = [4, 5] |
17 | 18 | sampling_to_test = ["mw", "mwss", "dh", "gl"] |
18 | | -method_to_test = ["numpy", "jax"] |
| 19 | +method_to_test = ["numpy", "jax", "torch"] |
19 | 20 | reality_to_test = [False, True] |
20 | 21 | multiple_gpus = [False, True] |
21 | 22 |
|
@@ -59,7 +60,7 @@ def test_transform_inverse( |
59 | 60 | else: |
60 | 61 | precomps = None |
61 | 62 | f = spherical.inverse( |
62 | | - flm, |
| 63 | + torch.from_numpy(flm) if method == "orch" else flm, |
63 | 64 | L, |
64 | 65 | spin, |
65 | 66 | sampling=sampling, |
@@ -87,10 +88,9 @@ def test_transform_inverse_healpix( |
87 | 88 | flm = flm_generator(L=L, spin=0, reality=True) |
88 | 89 | flm_hp = samples.flm_2d_to_hp(flm, L) |
89 | 90 | f_check = hp.sphtfunc.alm2map(flm_hp, nside, lmax=L - 1) |
90 | | - |
91 | 91 | precomps = generate_precomputes(L, 0, sampling, nside, False) |
92 | 92 | f = spherical.inverse( |
93 | | - flm, |
| 93 | + torch.from_numpy(flm) if method == "torch" else flm, |
94 | 94 | L, |
95 | 95 | spin=0, |
96 | 96 | nside=nside, |
@@ -142,8 +142,9 @@ def test_transform_forward( |
142 | 142 | precomps = generate_precomputes(L, spin, sampling, None, True, L_lower) |
143 | 143 | else: |
144 | 144 | precomps = None |
| 145 | + |
145 | 146 | flm_check = spherical.forward( |
146 | | - f, |
| 147 | + torch.from_numpy(f) if method == "torch" else f, |
147 | 148 | L, |
148 | 149 | spin, |
149 | 150 | sampling=sampling, |
@@ -173,10 +174,9 @@ def test_transform_forward_healpix( |
173 | 174 | flm = flm_generator(L=L, spin=0, reality=True) |
174 | 175 | flm_hp = samples.flm_2d_to_hp(flm, L) |
175 | 176 | f = hp.sphtfunc.alm2map(flm_hp, nside, lmax=L - 1) |
176 | | - |
177 | 177 | precomps = generate_precomputes(L, 0, sampling, nside, True) |
178 | 178 | flm_check = spherical.forward( |
179 | | - f, |
| 179 | + torch.from_numpy(f) if method == "torch" else f, |
180 | 180 | L, |
181 | 181 | spin=0, |
182 | 182 | nside=nside, |
|
0 commit comments