diff --git a/s2fft/precompute_transforms/custom_ops.py b/s2fft/precompute_transforms/custom_ops.py index b6a84da3..625c0350 100644 --- a/s2fft/precompute_transforms/custom_ops.py +++ b/s2fft/precompute_transforms/custom_ops.py @@ -1,8 +1,5 @@ -from functools import partial - import jax.numpy as jnp import numpy as np -from jax import jit def wigner_subset_to_s2( @@ -86,7 +83,6 @@ def wigner_subset_to_s2( return np.fft.ifft(x, axis=-2, norm="forward") -@partial(jit, static_argnums=(3, 4)) def wigner_subset_to_s2_jax( flmn: jnp.ndarray, spins: jnp.ndarray, @@ -209,7 +205,6 @@ def so3_to_wigner_subset( return s2_to_wigner_subset(x, spins, DW, L, sampling) -@partial(jit, static_argnums=(3, 4, 5)) def so3_to_wigner_subset_jax( f: jnp.ndarray, spins: jnp.ndarray, @@ -338,7 +333,6 @@ def s2_to_wigner_subset( return x * (2.0 * np.pi) ** 2 -@partial(jit, static_argnums=(3, 4)) def s2_to_wigner_subset_jax( fs: jnp.ndarray, spins: jnp.ndarray,