Skip to content

Commit cd3988d

Browse files
committed
Refactor method dispatch logic in Wigner transforms
1 parent ff8513d commit cd3988d

File tree

1 file changed

+53
-22
lines changed

1 file changed

+53
-22
lines changed

s2fft/transforms/wigner.py

Lines changed: 53 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -81,21 +81,30 @@ def inverse(
8181
IEEE Transactions on Signal Processing 59 (2011): 5876-5887.
8282
8383
"""
84-
if N >= 8 and method in ["numpy", "jax"]:
84+
if method not in _inverse_functions:
85+
raise ValueError(f"Method {method} not recognised.")
86+
87+
if N >= 8 and method in ("numpy", "jax"):
8588
raise Warning("Recursive transform may provide lower precision beyond N ~ 8")
8689

87-
if method == "numpy":
88-
return inverse_numpy(flmn, L, N, nside, sampling, reality, precomps, L_lower)
89-
elif method == "jax":
90-
return inverse_jax(flmn, L, N, nside, sampling, reality, precomps, L_lower)
91-
elif method == "jax_ssht":
90+
inverse_kwargs = {
91+
"flmn": flmn,
92+
"L": L,
93+
"N": N,
94+
"L_lower": L_lower,
95+
"sampling": sampling,
96+
"reality": reality,
97+
}
98+
99+
if method in ("jax", "numpy"):
100+
inverse_kwargs.update(nside=nside, precomps=precomps)
101+
102+
if method == "jax_ssht":
92103
if sampling.lower() == "healpix":
93104
raise ValueError("SSHT does not support healpix sampling.")
94-
return inverse_jax_ssht(flmn, L, N, L_lower, sampling, reality, _ssht_backend)
95-
else:
96-
raise ValueError(
97-
f"Implementation {method} not recognised. Should be either numpy or jax."
98-
)
105+
inverse_kwargs["_ssht_backend"] = _ssht_backend
106+
107+
return _inverse_functions[method](**inverse_kwargs)
99108

100109

101110
def inverse_numpy(
@@ -401,21 +410,30 @@ def forward(
401410
IEEE Transactions on Signal Processing 59 (2011): 5876-5887.
402411
403412
"""
404-
if N >= 8 and method in ["numpy", "jax"]:
413+
if method not in _inverse_functions:
414+
raise ValueError(f"Method {method} not recognised.")
415+
416+
if N >= 8 and method in ("numpy", "jax"):
405417
raise Warning("Recursive transform may provide lower precision beyond N ~ 8")
406418

407-
if method == "numpy":
408-
return forward_numpy(f, L, N, nside, sampling, reality, precomps, L_lower)
409-
elif method == "jax":
410-
return forward_jax(f, L, N, nside, sampling, reality, precomps, L_lower)
411-
elif method == "jax_ssht":
419+
forward_kwargs = {
420+
"f": f,
421+
"L": L,
422+
"N": N,
423+
"L_lower": L_lower,
424+
"sampling": sampling,
425+
"reality": reality,
426+
}
427+
428+
if method in ("jax", "numpy"):
429+
forward_kwargs.update(nside=nside, precomps=precomps)
430+
431+
if method == "jax_ssht":
412432
if sampling.lower() == "healpix":
413433
raise ValueError("SSHT does not support healpix sampling.")
414-
return forward_jax_ssht(f, L, N, L_lower, sampling, reality, _ssht_backend)
415-
else:
416-
raise ValueError(
417-
f"Implementation {method} not recognised. Should be either numpy or jax."
418-
)
434+
forward_kwargs["_ssht_backend"] = _ssht_backend
435+
436+
return _forward_functions[method](**forward_kwargs)
419437

420438

421439
def forward_numpy(
@@ -805,3 +823,16 @@ def _fban_to_f(fban: jnp.ndarray, L: int, N: int, reality: bool = False) -> jnp.
805823
else:
806824
f = jnp.fft.ifft(jnp.fft.ifftshift(fban, axes=-3), axis=-3, norm="forward")
807825
return f
826+
827+
828+
_inverse_functions = {
829+
"numpy": inverse_numpy,
830+
"jax": inverse_jax,
831+
"jax_ssht": inverse_jax_ssht,
832+
}
833+
834+
_forward_functions = {
835+
"numpy": forward_numpy,
836+
"jax": forward_jax,
837+
"jax_ssht": forward_jax_ssht,
838+
}

0 commit comments

Comments
 (0)