Skip to content

Commit bdf8fc9

Browse files
committed
Deduplicate dispatching logic in forward spherical transform
1 parent 909e6f1 commit bdf8fc9

File tree

1 file changed

+16
-29
lines changed

1 file changed

+16
-29
lines changed

s2fft/transforms/spherical.py

Lines changed: 16 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -401,35 +401,22 @@ def forward(
401401
if spin >= 8 and method in ["numpy", "jax"]:
402402
raise Warning("Recursive transform may provide lower precision beyond spin ~ 8")
403403

404-
if method == "numpy":
405-
return forward_numpy(f, L, spin, nside, sampling, reality, precomps, L_lower)
406-
elif method == "jax":
407-
return forward_jax(
408-
f,
409-
L,
410-
spin,
411-
nside,
412-
sampling,
413-
reality,
414-
precomps,
415-
spmd,
416-
L_lower,
417-
use_healpix_custom_primitive=False,
418-
)
419-
elif method == "cuda":
420-
return forward_jax(
421-
f,
422-
L,
423-
spin,
424-
nside,
425-
sampling,
426-
reality,
427-
precomps,
428-
spmd,
429-
L_lower,
430-
use_healpix_custom_primitive=True,
431-
)
432-
404+
if method in {"numpy", "jax", "cuda"}:
405+
kwargs = {
406+
"f": f,
407+
"L": L,
408+
"spin": spin,
409+
"nside": nside,
410+
"sampling": sampling,
411+
"reality": reality,
412+
"precomps": precomps,
413+
"L_lower": L_lower,
414+
}
415+
if method in {"jax", "cuda"}:
416+
kwargs["spmd"] = spmd
417+
kwargs["use_healpix_custom_primitive"] = method == "cuda"
418+
forward_function = forward_numpy if method == "numpy" else forward_jax
419+
return forward_function(**kwargs)
433420
elif method == "jax_ssht":
434421
if sampling.lower() == "healpix":
435422
raise ValueError("SSHT does not support healpix sampling.")

0 commit comments

Comments
 (0)