Skip to content

Commit ff8513d

Browse files
committed
Include torch wrappers in tests
1 parent 364a6f6 commit ff8513d

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

tests/test_spherical_transform.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import numpy as np
44
import pyssht as ssht
55
import pytest
6+
import torch
67

78
from s2fft.recursions.price_mcewen import generate_precomputes
89
from s2fft.sampling import s2_samples as samples
@@ -15,7 +16,7 @@
1516
spin_to_test = [-2, 0, 1]
1617
nside_to_test = [4, 5]
1718
sampling_to_test = ["mw", "mwss", "dh", "gl"]
18-
method_to_test = ["numpy", "jax"]
19+
method_to_test = ["numpy", "jax", "torch"]
1920
reality_to_test = [False, True]
2021
multiple_gpus = [False, True]
2122

@@ -59,7 +60,7 @@ def test_transform_inverse(
5960
else:
6061
precomps = None
6162
f = spherical.inverse(
62-
flm,
63+
torch.from_numpy(flm) if method == "orch" else flm,
6364
L,
6465
spin,
6566
sampling=sampling,
@@ -87,10 +88,9 @@ def test_transform_inverse_healpix(
8788
flm = flm_generator(L=L, spin=0, reality=True)
8889
flm_hp = samples.flm_2d_to_hp(flm, L)
8990
f_check = hp.sphtfunc.alm2map(flm_hp, nside, lmax=L - 1)
90-
9191
precomps = generate_precomputes(L, 0, sampling, nside, False)
9292
f = spherical.inverse(
93-
flm,
93+
torch.from_numpy(flm) if method == "torch" else flm,
9494
L,
9595
spin=0,
9696
nside=nside,
@@ -142,8 +142,9 @@ def test_transform_forward(
142142
precomps = generate_precomputes(L, spin, sampling, None, True, L_lower)
143143
else:
144144
precomps = None
145+
145146
flm_check = spherical.forward(
146-
f,
147+
torch.from_numpy(f) if method == "torch" else f,
147148
L,
148149
spin,
149150
sampling=sampling,
@@ -173,10 +174,9 @@ def test_transform_forward_healpix(
173174
flm = flm_generator(L=L, spin=0, reality=True)
174175
flm_hp = samples.flm_2d_to_hp(flm, L)
175176
f = hp.sphtfunc.alm2map(flm_hp, nside, lmax=L - 1)
176-
177177
precomps = generate_precomputes(L, 0, sampling, nside, True)
178178
flm_check = spherical.forward(
179-
f,
179+
torch.from_numpy(f) if method == "torch" else f,
180180
L,
181181
spin=0,
182182
nside=nside,

0 commit comments

Comments
 (0)