-
Notifications
You must be signed in to change notification settings - Fork 14
Update custom_ops.py #315
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Update custom_ops.py #315
Conversation
Small compatibility change which disables jitting on the s2fft side, in turn enables higher level jitting in s2ai.
Codecov Report✅ All modified and coverable lines are covered by tests. Additional details and impacted files@@ Coverage Diff @@
## main #315 +/- ##
==========================================
- Coverage 96.51% 96.50% -0.01%
==========================================
Files 32 32
Lines 3439 3434 -5
==========================================
- Hits 3319 3314 -5
Misses 120 120 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
matt-graham
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@kmulderdas - is having these functions jitted blocking you jitting a function which calls them in s2ai? Generally jax.jit wrapped functions should be able to be arbitrarily nested and there is some suggestion it can help with compilation times when a lower-level function is reused multiple times in a higher level function. I suspect removing the jit decorators on the utility functions here won't have any major performance implications but it would be good to understand why its causing issues as we widely apply jax.jit to other functions in s2fft which are likely to be contained in higher-level jitted functions.
|
|
||
|
|
||
| @partial(jit, static_argnums=(3, 4)) | ||
| # @partial(jit, static_argnums=(3, 4)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| # @partial(jit, static_argnums=(3, 4)) |
We generally shouldn't comment out code as we can always recover snippets from git history - this is likely to be what is causing the linting failures.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These jitted functions are called in the lifted convolution layers in s2ai. For performance and general functionality reasons it is desirable to have the flax model, build from these layers, be traceable and jittable at the top level. These utility functions specifically cause errors when trying to trace/jit at the aforementioned level. It is not clear to me why the other imported jitted functions from s2fft or the ones natively defined in s2ai don't break in the same way.
|
|
||
|
|
||
| @partial(jit, static_argnums=(3, 4, 5)) | ||
| # @partial(jit, static_argnums=(3, 4, 5)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| # @partial(jit, static_argnums=(3, 4, 5)) |
|
|
||
|
|
||
| @partial(jit, static_argnums=(3, 4)) | ||
| # @partial(jit, static_argnums=(3, 4)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| # @partial(jit, static_argnums=(3, 4)) |
Removed commented lines for linting purposes
|
Hmm actually linting check is still failing, looks like some reformatting is still happening. Will reformat locally and see if fixes. EDIT: Turns out Ruff check was failing due to now unused imports not having been removed. Hopefully last commit should resolve this. |
* Update custom_ops.py Small compatibility change which disables jitting on the s2fft side, in turn enables higher level jitting in s2ai. * Update custom_ops.py Removed commented lines for linting purposes * Removing now unused imports --------- Co-authored-by: Matt Graham <[email protected]>
Small compatibility change which disables jitting on the s2fft side, in turn enables higher level jitting in s2ai.