Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ Random123 = "74087812-796a-5b5d-8853-05524746bad3"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
YaoBlocks = "418bc28f-b43b-5e0b-a6e7-61bbc1a2c1df"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[sources]
ReactantCore = {path = "lib/ReactantCore"}
Expand All @@ -67,6 +68,7 @@ ReactantRandom123Ext = "Random123"
ReactantSpecialFunctionsExt = "SpecialFunctions"
ReactantStatisticsExt = "Statistics"
ReactantYaoBlocksExt = "YaoBlocks"
ReactantZygoteExt = "Zygote"

[compat]
AbstractFFTs = "1.5"
Expand Down Expand Up @@ -107,6 +109,7 @@ Sockets = "1.10"
SpecialFunctions = "2.4"
Statistics = "1.10"
YaoBlocks = "0.13, 0.14"
Zygote = "0.7"
julia = "1.10"
unzip_jll = "6"

Expand Down
5 changes: 5 additions & 0 deletions docs/src/api/config.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ Reactant.PrecisionConfig
Reactant.DotGeneralAlgorithm
```

### Zygote Overlay

- `OVERLAY_ZYGOTE_CALLS`: Whether to overlay `Zygote.gradient` calls with `Enzyme.autodiff`
calls.

## Environment Variables

The following environment variables can be used to configure Reactant.
Expand Down
29 changes: 29 additions & 0 deletions ext/ReactantZygoteExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
module ReactantZygoteExt

using Reactant:
Reactant, CallWithReactant, @reactant_overlay, use_overlayed_version, call_with_reactant
using Zygote: Zygote
using Enzyme: Enzyme, Reverse, Active, Const, Duplicated

# TODO: overload the following as well
# - Zygote.pullback
# - Zygote.jacobian
# - Zygote.hessian

@reactant_overlay function Zygote.gradient(f::F, args...) where {F}
# TODO: check `f` as well once #1642 is merged
if Reactant.OVERLAY_ZYGOTE_CALLS[] && use_overlayed_version(args)
@warn "Reactant doesn't support using Zygote for computing gradients. Replacing \
`Zygote.gradient` with `Enzyme.autodiff` call. Please update your code to \
not use `Zygote.gradient` and instead use `Enzyme.gradient` inside \
`Reactant.@compile`. If this behavior is undesirable, set the \
`overlay_zygote_calls` scoped value via `Reactant.with_config` to \
`false`.\n\nReactant can remove this switching without any breaking change \
and hence reliance on this behavior is strongly discouraged."
return Enzyme.gradient(Reverse, Const(f), args...)
else
return Base.inferencebarrier(Zygote.gradient)(CallWithReactant(f), args...)
end
end

end
11 changes: 11 additions & 0 deletions src/Configuration.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ scope will use the provided values.
or [`PrecisionConfig`](@ref). Defaults to `PrecisionConfig.DEFAULT`.
- `convolution_precision`: Precision for `stablehlo.convolution`. Can be `nothing`,
or [`PrecisionConfig`](@ref). Defaults to `PrecisionConfig.DEFAULT`.

### Zygote Overlay

- `overlay_zygote_calls`: Whether to overlay `Zygote.gradient` calls with
`Enzyme.autodiff` calls. Defaults to `true`.
"""
function with_config(
f;
Expand All @@ -38,6 +43,7 @@ function with_config(
convolution_precision=missing,
lower_partialsort_to_approx_top_k=missing,
fallback_approx_top_k_lowering=missing,
overlay_zygote_calls=missing,
)
config_vars = ()
dot_general_algorithm !== missing &&
Expand All @@ -58,6 +64,8 @@ function with_config(
FALLBACK_APPROX_TOP_K_LOWERING => fallback_approx_top_k_lowering,
)
)
overlay_zygote_calls !== missing &&
(config_vars = (config_vars..., OVERLAY_ZYGOTE_CALLS => overlay_zygote_calls))

return ScopedValues.with(f, config_vars...)
end
Expand Down Expand Up @@ -379,3 +387,6 @@ function DotGeneralAlgorithm(

return nothing
end

# Overlay Zygote.jl
const OVERLAY_ZYGOTE_CALLS = ScopedValue(true)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel like it might be better to default to false, at least to start? could be convinced either way

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't like the idea of a default that errors out almost always. Rn we should always work with the switching. If anyone really disagrees with switching they can easy opt-out in which case their code will just crash

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suppose that's fair, and I'm okay with this. I guess the specific caveat being that I think we should reserve the right to swap the default (and do so once either more downstream things are set to use enzyme properly and/or we fix broadcasting or other limitations)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and yeah I do agree it's a lot better to get an early error message with a backtrace where it's at least possible to see where to do the switch

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated the text a bit more.

2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
Adapt = "4.1"
Expand Down Expand Up @@ -65,6 +66,7 @@ StableRNGs = "1"
Statistics = "1.10"
StatsBase = "0.34"
Test = "1.10"
Zygote = "0.7"

[extras]
Reactant_jll = "0192cb87-2b54-54ad-80e0-3be72ad8a3c0"
22 changes: 22 additions & 0 deletions test/integration/zygote.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
using Zygote, Reactant, Enzyme, Test

sumabs2(x) = sum(abs2, x)

@testset "Zygote" begin
@testset "Zygote.gradient" begin
x = Reactant.to_rarray(rand(Float32, 32, 10))

zyg_grad = @jit Zygote.gradient(sumabs2, x)
enz_grad = @jit Enzyme.gradient(Reverse, Const(sumabs2), x)
@test zyg_grad[1] isa Reactant.ConcreteRArray
@test enz_grad[1] ≈ zyg_grad[1]

@testset "Disable Overlay" begin
@test_throws Zygote.CompileError Reactant.with_config(;
overlay_zygote_calls=false
) do
@jit Zygote.gradient(sumabs2, x)
end
end
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all"))
@safetestset "Python" include("integration/python.jl")
@safetestset "Optimisers" include("integration/optimisers.jl")
@safetestset "FillArrays" include("integration/fillarrays.jl")
@safetestset "Zygote" include("integration/zygote.jl")
end

if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "neural_networks"
Expand Down
Loading