Skip to content

Commit b288167

Browse files
committed
Deduplicate precompute transform wrappers
1 parent f50eaee commit b288167

File tree

1 file changed

+14
-14
lines changed

1 file changed

+14
-14
lines changed

s2fft/precompute_transforms/spherical.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -62,14 +62,14 @@ def inverse(
6262
+ "Defering to complex transform.",
6363
stacklevel=2,
6464
)
65-
if method == "numpy":
66-
return inverse_transform(flm, kernel, L, sampling, reality, spin, nside)
67-
elif method == "jax":
68-
return inverse_transform_jax(flm, kernel, L, sampling, reality, spin, nside)
69-
elif method == "torch":
70-
return inverse_transform_torch(flm, kernel, L, sampling, reality, spin, nside)
71-
else:
65+
inverse_functions = {
66+
"numpy": inverse_transform,
67+
"jax": inverse_transform_jax,
68+
"torch": inverse_transform_torch,
69+
}
70+
if method not in inverse_functions:
7271
raise ValueError(f"Method {method} not recognised.")
72+
return inverse_functions[method](flm, kernel, L, sampling, reality, spin, nside)
7373

7474

7575
def inverse_transform(
@@ -337,14 +337,14 @@ def forward(
337337
+ "Defering to complex transform.",
338338
stacklevel=2,
339339
)
340-
if method == "numpy":
341-
return forward_transform(f, kernel, L, sampling, reality, spin, nside)
342-
elif method == "jax":
343-
return forward_transform_jax(f, kernel, L, sampling, reality, spin, nside)
344-
elif method == "torch":
345-
return forward_transform_torch(f, kernel, L, sampling, reality, spin, nside)
346-
else:
340+
forward_functions = {
341+
"numpy": forward_transform,
342+
"jax": forward_transform_jax,
343+
"torch": forward_transform_torch,
344+
}
345+
if method not in forward_functions:
347346
raise ValueError(f"Method {method} not recognised.")
347+
return forward_functions[method](f, kernel, L, sampling, reality, spin, nside)
348348

349349

350350
def forward_transform(

0 commit comments

Comments
 (0)