Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions s2fft/precompute_transforms/custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def wigner_subset_to_s2(
return np.fft.ifft(x, axis=-2, norm="forward")


@partial(jit, static_argnums=(3, 4))
# @partial(jit, static_argnums=(3, 4))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# @partial(jit, static_argnums=(3, 4))

We generally shouldn't comment out code as we can always recover snippets from git history - this is likely to be what is causing the linting failures.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These jitted functions are called in the lifted convolution layers in s2ai. For performance and general functionality reasons it is desirable to have the flax model, build from these layers, be traceable and jittable at the top level. These utility functions specifically cause errors when trying to trace/jit at the aforementioned level. It is not clear to me why the other imported jitted functions from s2fft or the ones natively defined in s2ai don't break in the same way.

def wigner_subset_to_s2_jax(
flmn: jnp.ndarray,
spins: jnp.ndarray,
Expand Down Expand Up @@ -209,7 +209,7 @@ def so3_to_wigner_subset(
return s2_to_wigner_subset(x, spins, DW, L, sampling)


@partial(jit, static_argnums=(3, 4, 5))
# @partial(jit, static_argnums=(3, 4, 5))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# @partial(jit, static_argnums=(3, 4, 5))

def so3_to_wigner_subset_jax(
f: jnp.ndarray,
spins: jnp.ndarray,
Expand Down Expand Up @@ -338,7 +338,7 @@ def s2_to_wigner_subset(
return x * (2.0 * np.pi) ** 2


@partial(jit, static_argnums=(3, 4))
# @partial(jit, static_argnums=(3, 4))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# @partial(jit, static_argnums=(3, 4))

def s2_to_wigner_subset_jax(
fs: jnp.ndarray,
spins: jnp.ndarray,
Expand Down
Loading