Skip to content

Conversation

@kmulderdas
Copy link
Contributor

Small compatibility change which disables jitting on the s2fft side, in turn enables higher level jitting in s2ai.

Small compatibility change which disables jitting on the s2fft side, in turn enables higher level jitting in s2ai.
@codecov
Copy link

codecov bot commented Jul 9, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 96.50%. Comparing base (d77e9cb) to head (1a63620).
⚠️ Report is 3 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #315      +/-   ##
==========================================
- Coverage   96.51%   96.50%   -0.01%     
==========================================
  Files          32       32              
  Lines        3439     3434       -5     
==========================================
- Hits         3319     3314       -5     
  Misses        120      120              

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Copy link
Collaborator

@matt-graham matt-graham left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kmulderdas - is having these functions jitted blocking you jitting a function which calls them in s2ai? Generally jax.jit wrapped functions should be able to be arbitrarily nested and there is some suggestion it can help with compilation times when a lower-level function is reused multiple times in a higher level function. I suspect removing the jit decorators on the utility functions here won't have any major performance implications but it would be good to understand why its causing issues as we widely apply jax.jit to other functions in s2fft which are likely to be contained in higher-level jitted functions.



@partial(jit, static_argnums=(3, 4))
# @partial(jit, static_argnums=(3, 4))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# @partial(jit, static_argnums=(3, 4))

We generally shouldn't comment out code as we can always recover snippets from git history - this is likely to be what is causing the linting failures.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These jitted functions are called in the lifted convolution layers in s2ai. For performance and general functionality reasons it is desirable to have the flax model, build from these layers, be traceable and jittable at the top level. These utility functions specifically cause errors when trying to trace/jit at the aforementioned level. It is not clear to me why the other imported jitted functions from s2fft or the ones natively defined in s2ai don't break in the same way.



@partial(jit, static_argnums=(3, 4, 5))
# @partial(jit, static_argnums=(3, 4, 5))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# @partial(jit, static_argnums=(3, 4, 5))



@partial(jit, static_argnums=(3, 4))
# @partial(jit, static_argnums=(3, 4))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# @partial(jit, static_argnums=(3, 4))

Removed commented lines for linting purposes
@matt-graham
Copy link
Collaborator

matt-graham commented Aug 11, 2025

Hmm actually linting check is still failing, looks like some reformatting is still happening. Will reformat locally and see if fixes.

EDIT: Turns out Ruff check was failing due to now unused imports not having been removed. Hopefully last commit should resolve this.

@matt-graham matt-graham merged commit 21d2d0c into main Aug 11, 2025
14 checks passed
@matt-graham matt-graham deleted the km/s2ai-compatability branch August 11, 2025 16:12
ASKabalan pushed a commit that referenced this pull request Nov 11, 2025
* Update custom_ops.py

Small compatibility change which disables jitting on the s2fft side, in turn enables higher level jitting in s2ai.

* Update custom_ops.py

Removed commented lines for linting purposes

* Removing now unused imports

---------

Co-authored-by: Matt Graham <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants