-
Couldn't load subscription status.
- Fork 33
feat: overlay Zygote.gradient and use Enzyme instead #1658
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
ext/ReactantZygoteExt.jl
Outdated
|
|
||
| @reactant_overlay function Zygote.gradient(f::F, args...) where {F} | ||
| # TODO: check `f` as well once #1642 is merged | ||
| if use_overlayed_version(args) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we do this, we should quite aggressively yell that we're gonig to do this -- I would even be okay saying to do this for each call [not even each callsite]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm also kind of debating if we want to have this behind a feature flag as well [as perhaps it is useful to compare the performance of zygote as a frontend vs us inside the compiler]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That is what I was testing first, but we need to fix our broadcasting quirks before we can work through ChainRules rrules.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yeah....we should definitely fix that
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This will error inside Zygote but we have an option
Reactant.with_config(; overlay_zygote_calls=false) do
@jit Zygote.gradient(sumabs2, x)
end
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #1658 +/- ##
==========================================
- Coverage 68.76% 68.55% -0.21%
==========================================
Files 103 104 +1
Lines 11380 11567 +187
==========================================
+ Hits 7825 7930 +105
- Misses 3555 3637 +82 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
| end | ||
|
|
||
| # Overlay Zygote.jl | ||
| const OVERLAY_ZYGOTE_CALLS = ScopedValue(true) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I feel like it might be better to default to false, at least to start? could be convinced either way
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't like the idea of a default that errors out almost always. Rn we should always work with the switching. If anyone really disagrees with switching they can easy opt-out in which case their code will just crash
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I suppose that's fair, and I'm okay with this. I guess the specific caveat being that I think we should reserve the right to swap the default (and do so once either more downstream things are set to use enzyme properly and/or we fix broadcasting or other limitations)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
and yeah I do agree it's a lot better to get an early error message with a backtrace where it's at least possible to see where to do the switch
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Updated the text a bit more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
At a high level, I'm not a big fan of magically replacing one package with another one. Since you're overlaying Zygote, one option would be to throw an informative error rather than doing the magic substitution? Or the idea is to let people keep their existing code? At least I appreciate this is documented (and kinda like the idea of showing a warning when this magic replacement happens, although warnings are easy to miss in automated runs, or could clutter log files)
Edit: ok, I see this was discussed at #1658 (comment)
The idea is to make people progressively switch to Enzyme inside Reactant. Earlier today one of our students had switched the models and parameters to use reactant (xref SciML/NeuralPDE.jl#967) but did not switch the AD she was using. Instead of crashing we now do the switching with a loud enough warning that users know that they need to eventually fix this behavior. |
|
marking as draft to avoid accidental merge before others have a chance to comment |
|
Best way to fix people's code and performance 10/10 |
|
I'm mildly in favour of this. |
I know @ChrisRackauckas will like this 😅