@@ -62,14 +62,14 @@ def inverse(
6262 + "Defering to complex transform." ,
6363 stacklevel = 2 ,
6464 )
65- if method == "numpy" :
66- return inverse_transform (flm , kernel , L , sampling , reality , spin , nside )
67- elif method == "jax" :
68- return inverse_transform_jax (flm , kernel , L , sampling , reality , spin , nside )
69- elif method == "torch" :
70- return inverse_transform_torch (flm , kernel , L , sampling , reality , spin , nside )
71- else :
65+ inverse_functions = {
66+ "numpy" : inverse_transform ,
67+ "jax" : inverse_transform_jax ,
68+ "torch" : inverse_transform_torch ,
69+ }
70+ if method not in inverse_functions :
7271 raise ValueError (f"Method { method } not recognised." )
72+ return inverse_functions [method ](flm , kernel , L , sampling , reality , spin , nside )
7373
7474
7575def inverse_transform (
@@ -337,14 +337,14 @@ def forward(
337337 + "Defering to complex transform." ,
338338 stacklevel = 2 ,
339339 )
340- if method == "numpy" :
341- return forward_transform (f , kernel , L , sampling , reality , spin , nside )
342- elif method == "jax" :
343- return forward_transform_jax (f , kernel , L , sampling , reality , spin , nside )
344- elif method == "torch" :
345- return forward_transform_torch (f , kernel , L , sampling , reality , spin , nside )
346- else :
340+ forward_functions = {
341+ "numpy" : forward_transform ,
342+ "jax" : forward_transform_jax ,
343+ "torch" : forward_transform_torch ,
344+ }
345+ if method not in forward_functions :
347346 raise ValueError (f"Method { method } not recognised." )
347+ return forward_functions [method ](f , kernel , L , sampling , reality , spin , nside )
348348
349349
350350def forward_transform (
0 commit comments