@@ -81,21 +81,30 @@ def inverse(
8181 IEEE Transactions on Signal Processing 59 (2011): 5876-5887.
8282
8383 """
84- if N >= 8 and method in ["numpy" , "jax" ]:
84+ if method not in _inverse_functions :
85+ raise ValueError (f"Method { method } not recognised." )
86+
87+ if N >= 8 and method in ("numpy" , "jax" ):
8588 raise Warning ("Recursive transform may provide lower precision beyond N ~ 8" )
8689
87- if method == "numpy" :
88- return inverse_numpy (flmn , L , N , nside , sampling , reality , precomps , L_lower )
89- elif method == "jax" :
90- return inverse_jax (flmn , L , N , nside , sampling , reality , precomps , L_lower )
91- elif method == "jax_ssht" :
90+ inverse_kwargs = {
91+ "flmn" : flmn ,
92+ "L" : L ,
93+ "N" : N ,
94+ "L_lower" : L_lower ,
95+ "sampling" : sampling ,
96+ "reality" : reality ,
97+ }
98+
99+ if method in ("jax" , "numpy" ):
100+ inverse_kwargs .update (nside = nside , precomps = precomps )
101+
102+ if method == "jax_ssht" :
92103 if sampling .lower () == "healpix" :
93104 raise ValueError ("SSHT does not support healpix sampling." )
94- return inverse_jax_ssht (flmn , L , N , L_lower , sampling , reality , _ssht_backend )
95- else :
96- raise ValueError (
97- f"Implementation { method } not recognised. Should be either numpy or jax."
98- )
105+ inverse_kwargs ["_ssht_backend" ] = _ssht_backend
106+
107+ return _inverse_functions [method ](** inverse_kwargs )
99108
100109
101110def inverse_numpy (
@@ -401,21 +410,30 @@ def forward(
401410 IEEE Transactions on Signal Processing 59 (2011): 5876-5887.
402411
403412 """
404- if N >= 8 and method in ["numpy" , "jax" ]:
413+ if method not in _inverse_functions :
414+ raise ValueError (f"Method { method } not recognised." )
415+
416+ if N >= 8 and method in ("numpy" , "jax" ):
405417 raise Warning ("Recursive transform may provide lower precision beyond N ~ 8" )
406418
407- if method == "numpy" :
408- return forward_numpy (f , L , N , nside , sampling , reality , precomps , L_lower )
409- elif method == "jax" :
410- return forward_jax (f , L , N , nside , sampling , reality , precomps , L_lower )
411- elif method == "jax_ssht" :
419+ forward_kwargs = {
420+ "f" : f ,
421+ "L" : L ,
422+ "N" : N ,
423+ "L_lower" : L_lower ,
424+ "sampling" : sampling ,
425+ "reality" : reality ,
426+ }
427+
428+ if method in ("jax" , "numpy" ):
429+ forward_kwargs .update (nside = nside , precomps = precomps )
430+
431+ if method == "jax_ssht" :
412432 if sampling .lower () == "healpix" :
413433 raise ValueError ("SSHT does not support healpix sampling." )
414- return forward_jax_ssht (f , L , N , L_lower , sampling , reality , _ssht_backend )
415- else :
416- raise ValueError (
417- f"Implementation { method } not recognised. Should be either numpy or jax."
418- )
434+ forward_kwargs ["_ssht_backend" ] = _ssht_backend
435+
436+ return _forward_functions [method ](** forward_kwargs )
419437
420438
421439def forward_numpy (
@@ -805,3 +823,16 @@ def _fban_to_f(fban: jnp.ndarray, L: int, N: int, reality: bool = False) -> jnp.
805823 else :
806824 f = jnp .fft .ifft (jnp .fft .ifftshift (fban , axes = - 3 ), axis = - 3 , norm = "forward" )
807825 return f
826+
827+
828+ _inverse_functions = {
829+ "numpy" : inverse_numpy ,
830+ "jax" : inverse_jax ,
831+ "jax_ssht" : inverse_jax_ssht ,
832+ }
833+
834+ _forward_functions = {
835+ "numpy" : forward_numpy ,
836+ "jax" : forward_jax ,
837+ "jax_ssht" : forward_jax_ssht ,
838+ }
0 commit comments