From 704863386019e07104a25b2dd1cfcb57b6a7ab06 Mon Sep 17 00:00:00 2001 From: Kevin Mulder <33317219+kmulderdas@users.noreply.github.com> Date: Wed, 9 Jul 2025 17:49:16 +0100 Subject: [PATCH 1/3] Update custom_ops.py Small compatibility change which disables jitting on the s2fft side, in turn enables higher level jitting in s2ai. --- s2fft/precompute_transforms/custom_ops.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/s2fft/precompute_transforms/custom_ops.py b/s2fft/precompute_transforms/custom_ops.py index b6a84da3..03cfe63f 100644 --- a/s2fft/precompute_transforms/custom_ops.py +++ b/s2fft/precompute_transforms/custom_ops.py @@ -86,7 +86,7 @@ def wigner_subset_to_s2( return np.fft.ifft(x, axis=-2, norm="forward") -@partial(jit, static_argnums=(3, 4)) +# @partial(jit, static_argnums=(3, 4)) def wigner_subset_to_s2_jax( flmn: jnp.ndarray, spins: jnp.ndarray, @@ -209,7 +209,7 @@ def so3_to_wigner_subset( return s2_to_wigner_subset(x, spins, DW, L, sampling) -@partial(jit, static_argnums=(3, 4, 5)) +# @partial(jit, static_argnums=(3, 4, 5)) def so3_to_wigner_subset_jax( f: jnp.ndarray, spins: jnp.ndarray, @@ -338,7 +338,7 @@ def s2_to_wigner_subset( return x * (2.0 * np.pi) ** 2 -@partial(jit, static_argnums=(3, 4)) +# @partial(jit, static_argnums=(3, 4)) def s2_to_wigner_subset_jax( fs: jnp.ndarray, spins: jnp.ndarray, From ba14087fb8c829c7998bd15a34c2aacd62dadc21 Mon Sep 17 00:00:00 2001 From: Kevin Mulder <33317219+kmulderdas@users.noreply.github.com> Date: Thu, 31 Jul 2025 18:25:46 +0100 Subject: [PATCH 2/3] Update custom_ops.py Removed commented lines for linting purposes --- s2fft/precompute_transforms/custom_ops.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/s2fft/precompute_transforms/custom_ops.py b/s2fft/precompute_transforms/custom_ops.py index 03cfe63f..343030b2 100644 --- a/s2fft/precompute_transforms/custom_ops.py +++ b/s2fft/precompute_transforms/custom_ops.py @@ -86,7 +86,6 @@ def wigner_subset_to_s2( return np.fft.ifft(x, axis=-2, norm="forward") -# @partial(jit, static_argnums=(3, 4)) def wigner_subset_to_s2_jax( flmn: jnp.ndarray, spins: jnp.ndarray, @@ -209,7 +208,6 @@ def so3_to_wigner_subset( return s2_to_wigner_subset(x, spins, DW, L, sampling) -# @partial(jit, static_argnums=(3, 4, 5)) def so3_to_wigner_subset_jax( f: jnp.ndarray, spins: jnp.ndarray, @@ -338,7 +336,6 @@ def s2_to_wigner_subset( return x * (2.0 * np.pi) ** 2 -# @partial(jit, static_argnums=(3, 4)) def s2_to_wigner_subset_jax( fs: jnp.ndarray, spins: jnp.ndarray, From 1a63620b1885d90ded45638ec679e45496cf464f Mon Sep 17 00:00:00 2001 From: Matt Graham Date: Mon, 11 Aug 2025 16:42:26 +0100 Subject: [PATCH 3/3] Removing now unused imports --- s2fft/precompute_transforms/custom_ops.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/s2fft/precompute_transforms/custom_ops.py b/s2fft/precompute_transforms/custom_ops.py index 343030b2..625c0350 100644 --- a/s2fft/precompute_transforms/custom_ops.py +++ b/s2fft/precompute_transforms/custom_ops.py @@ -1,8 +1,5 @@ -from functools import partial - import jax.numpy as jnp import numpy as np -from jax import jit def wigner_subset_to_s2(