-
Couldn't load subscription status.
- Fork 33
feat: overlay Zygote.gradient and use Enzyme instead #1658
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
7c83518
bc56a8a
46e1792
2a27445
6dca23f
be5f2a7
6602657
f8785aa
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,29 @@ | ||
| module ReactantZygoteExt | ||
|
|
||
| using Reactant: | ||
| Reactant, CallWithReactant, @reactant_overlay, use_overlayed_version, call_with_reactant | ||
| using Zygote: Zygote | ||
| using Enzyme: Enzyme, Reverse, Active, Const, Duplicated | ||
|
|
||
| # TODO: overload the following as well | ||
| # - Zygote.pullback | ||
| # - Zygote.jacobian | ||
| # - Zygote.hessian | ||
|
|
||
| @reactant_overlay function Zygote.gradient(f::F, args...) where {F} | ||
| # TODO: check `f` as well once #1642 is merged | ||
| if use_overlayed_version(args) && Reactant.OVERLAY_ZYGOTE_CALLS[] | ||
| @warn "Reactant doesn't support using Zygote for computing gradients. Replacing \ | ||
| `Zygote.gradient` with `Enzyme.autodiff` call. Please update your code to \ | ||
| not use `Zygote.gradient` and instead use `Enzyme.gradient` inside \ | ||
| `Reactant.@compile`." | ||
| dargs = map(Enzyme.make_zero, args) | ||
| duplicated = map(Duplicated, args, dargs) | ||
| Reactant.overload_autodiff(Reverse, Const(f), Active, duplicated...) | ||
| return dargs | ||
| else | ||
| return Base.inferencebarrier(Zygote.gradient)(CallWithReactant(f), args...) | ||
| end | ||
| end | ||
|
|
||
| end | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -30,6 +30,11 @@ scope will use the provided values. | |
| or [`PrecisionConfig`](@ref). Defaults to `PrecisionConfig.DEFAULT`. | ||
| - `convolution_precision`: Precision for `stablehlo.convolution`. Can be `nothing`, | ||
| or [`PrecisionConfig`](@ref). Defaults to `PrecisionConfig.DEFAULT`. | ||
|
|
||
| ### Zygote Overlay | ||
|
|
||
| - `overlay_zygote_calls`: Whether to overlay `Zygote.gradient` calls with | ||
| `Enzyme.autodiff` calls. Defaults to `true`. | ||
| """ | ||
| function with_config( | ||
| f; | ||
|
|
@@ -38,6 +43,7 @@ function with_config( | |
| convolution_precision=missing, | ||
| lower_partialsort_to_approx_top_k=missing, | ||
| fallback_approx_top_k_lowering=missing, | ||
| overlay_zygote_calls=missing, | ||
| ) | ||
| config_vars = () | ||
| dot_general_algorithm !== missing && | ||
|
|
@@ -58,6 +64,9 @@ function with_config( | |
| FALLBACK_APPROX_TOP_K_LOWERING => fallback_approx_top_k_lowering, | ||
| ) | ||
| ) | ||
| overlay_zygote_calls !== missing && ( | ||
| config_vars = (config_vars..., OVERLAY_ZYGOTE_CALLS => overlay_zygote_calls) | ||
| ) | ||
avik-pal marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| return ScopedValues.with(f, config_vars...) | ||
| end | ||
|
|
@@ -379,3 +388,6 @@ function DotGeneralAlgorithm( | |
|
|
||
| return nothing | ||
| end | ||
|
|
||
| # Overlay Zygote.jl | ||
| const OVERLAY_ZYGOTE_CALLS = ScopedValue(true) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I feel like it might be better to default to false, at least to start? could be convinced either way There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't like the idea of a default that errors out almost always. Rn we should always work with the switching. If anyone really disagrees with switching they can easy opt-out in which case their code will just crash There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I suppose that's fair, and I'm okay with this. I guess the specific caveat being that I think we should reserve the right to swap the default (and do so once either more downstream things are set to use enzyme properly and/or we fix broadcasting or other limitations) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. and yeah I do agree it's a lot better to get an early error message with a backtrace where it's at least possible to see where to do the switch There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Updated the text a bit more. |
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,14 @@ | ||
| using Zygote, Reactant, Enzyme, Test | ||
|
|
||
| sumabs2(x) = sum(abs2, x) | ||
|
|
||
| @testset "Zygote" begin | ||
| @testset "Zygote.gradient" begin | ||
| x = Reactant.to_rarray(rand(Float32, 32, 10)) | ||
|
|
||
| zyg_grad = @jit Zygote.gradient(sumabs2, x) | ||
| enz_grad = @jit Enzyme.gradient(Reverse, Const(sumabs2), x) | ||
| @test zyg_grad[1] isa Reactant.ConcreteRArray | ||
| @test enz_grad[1] ≈ zyg_grad[1] | ||
| end | ||
| end |
Uh oh!
There was an error while loading. Please reload this page.