Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ Random123 = "74087812-796a-5b5d-8853-05524746bad3"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
YaoBlocks = "418bc28f-b43b-5e0b-a6e7-61bbc1a2c1df"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[sources]
ReactantCore = {path = "lib/ReactantCore"}
Expand All @@ -67,6 +68,7 @@ ReactantRandom123Ext = "Random123"
ReactantSpecialFunctionsExt = "SpecialFunctions"
ReactantStatisticsExt = "Statistics"
ReactantYaoBlocksExt = "YaoBlocks"
ReactantZygoteExt = "Zygote"

[compat]
AbstractFFTs = "1.5"
Expand Down Expand Up @@ -107,6 +109,7 @@ Sockets = "1.10"
SpecialFunctions = "2.4"
Statistics = "1.10"
YaoBlocks = "0.13, 0.14"
Zygote = "0.7"
julia = "1.10"
unzip_jll = "6"

Expand Down
5 changes: 5 additions & 0 deletions docs/src/api/config.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ Reactant.PrecisionConfig
Reactant.DotGeneralAlgorithm
```

### Zygote Overlay

- `OVERLAY_ZYGOTE_CALLS`: Whether to overlay `Zygote.gradient` calls with `Enzyme.autodiff`
calls.

## Environment Variables

The following environment variables can be used to configure Reactant.
Expand Down
27 changes: 27 additions & 0 deletions ext/ReactantZygoteExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
module ReactantZygoteExt

using Reactant:
Reactant, CallWithReactant, @reactant_overlay, use_overlayed_version, call_with_reactant
using Zygote: Zygote
using Enzyme: Enzyme, Reverse, Active, Const, Duplicated

# TODO: overload the following as well
# - Zygote.pullback
# - Zygote.jacobian
# - Zygote.hessian

@reactant_overlay function Zygote.gradient(f::F, args...) where {F}
# TODO: check `f` as well once #1642 is merged
if Reactant.OVERLAY_ZYGOTE_CALLS[] && use_overlayed_version(args)
@warn "Reactant doesn't support using Zygote for computing gradients. Replacing \
`Zygote.gradient` with `Enzyme.autodiff` call. Please update your code to \
not use `Zygote.gradient` and instead use `Enzyme.gradient` inside \
`Reactant.@compile`. If this behavior is undesirable, set the \
`overlay_zygote_calls` scoped value via `Reactant.with_config` to `false`."
return Enzyme.gradient(Reverse, Const(f), args...)
else
return Base.inferencebarrier(Zygote.gradient)(CallWithReactant(f), args...)
end
end

end
11 changes: 11 additions & 0 deletions src/Configuration.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ scope will use the provided values.
or [`PrecisionConfig`](@ref). Defaults to `PrecisionConfig.DEFAULT`.
- `convolution_precision`: Precision for `stablehlo.convolution`. Can be `nothing`,
or [`PrecisionConfig`](@ref). Defaults to `PrecisionConfig.DEFAULT`.

### Zygote Overlay

- `overlay_zygote_calls`: Whether to overlay `Zygote.gradient` calls with
`Enzyme.autodiff` calls. Defaults to `true`.
"""
function with_config(
f;
Expand All @@ -38,6 +43,7 @@ function with_config(
convolution_precision=missing,
lower_partialsort_to_approx_top_k=missing,
fallback_approx_top_k_lowering=missing,
overlay_zygote_calls=missing,
)
config_vars = ()
dot_general_algorithm !== missing &&
Expand All @@ -58,6 +64,8 @@ function with_config(
FALLBACK_APPROX_TOP_K_LOWERING => fallback_approx_top_k_lowering,
)
)
overlay_zygote_calls !== missing &&
(config_vars = (config_vars..., OVERLAY_ZYGOTE_CALLS => overlay_zygote_calls))

return ScopedValues.with(f, config_vars...)
end
Expand Down Expand Up @@ -379,3 +387,6 @@ function DotGeneralAlgorithm(

return nothing
end

# Overlay Zygote.jl
const OVERLAY_ZYGOTE_CALLS = ScopedValue(true)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel like it might be better to default to false, at least to start? could be convinced either way

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't like the idea of a default that errors out almost always. Rn we should always work with the switching. If anyone really disagrees with switching they can easy opt-out in which case their code will just crash

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suppose that's fair, and I'm okay with this. I guess the specific caveat being that I think we should reserve the right to swap the default (and do so once either more downstream things are set to use enzyme properly and/or we fix broadcasting or other limitations)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and yeah I do agree it's a lot better to get an early error message with a backtrace where it's at least possible to see where to do the switch

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated the text a bit more.

2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
Adapt = "4.1"
Expand Down Expand Up @@ -65,6 +66,7 @@ StableRNGs = "1"
Statistics = "1.10"
StatsBase = "0.34"
Test = "1.10"
Zygote = "0.7"

[extras]
Reactant_jll = "0192cb87-2b54-54ad-80e0-3be72ad8a3c0"
22 changes: 22 additions & 0 deletions test/integration/zygote.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
using Zygote, Reactant, Enzyme, Test

sumabs2(x) = sum(abs2, x)

@testset "Zygote" begin
@testset "Zygote.gradient" begin
x = Reactant.to_rarray(rand(Float32, 32, 10))

zyg_grad = @jit Zygote.gradient(sumabs2, x)
enz_grad = @jit Enzyme.gradient(Reverse, Const(sumabs2), x)
@test zyg_grad[1] isa Reactant.ConcreteRArray
@test enz_grad[1] ≈ zyg_grad[1]

@testset "Disable Overlay" begin
@test_throws Zygote.CompileError Reactant.with_config(;
overlay_zygote_calls=false
) do
@jit Zygote.gradient(sumabs2, x)
end
end
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all"))
@safetestset "Python" include("integration/python.jl")
@safetestset "Optimisers" include("integration/optimisers.jl")
@safetestset "FillArrays" include("integration/fillarrays.jl")
@safetestset "Zygote" include("integration/zygote.jl")
end

if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "neural_networks"
Expand Down
Loading