diff --git a/Project.toml b/Project.toml index 737ea28018..87bac86d19 100644 --- a/Project.toml +++ b/Project.toml @@ -46,6 +46,7 @@ Random123 = "74087812-796a-5b5d-8853-05524746bad3" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" YaoBlocks = "418bc28f-b43b-5e0b-a6e7-61bbc1a2c1df" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [sources] ReactantCore = {path = "lib/ReactantCore"} @@ -67,6 +68,7 @@ ReactantRandom123Ext = "Random123" ReactantSpecialFunctionsExt = "SpecialFunctions" ReactantStatisticsExt = "Statistics" ReactantYaoBlocksExt = "YaoBlocks" +ReactantZygoteExt = "Zygote" [compat] AbstractFFTs = "1.5" @@ -107,6 +109,7 @@ Sockets = "1.10" SpecialFunctions = "2.4" Statistics = "1.10" YaoBlocks = "0.13, 0.14" +Zygote = "0.7" julia = "1.10" unzip_jll = "6" diff --git a/docs/src/api/config.md b/docs/src/api/config.md index a3915c078c..79024b99f7 100644 --- a/docs/src/api/config.md +++ b/docs/src/api/config.md @@ -33,6 +33,11 @@ Reactant.PrecisionConfig Reactant.DotGeneralAlgorithm ``` +### Zygote Overlay + +- `OVERLAY_ZYGOTE_CALLS`: Whether to overlay `Zygote.gradient` calls with `Enzyme.autodiff` + calls. + ## Environment Variables The following environment variables can be used to configure Reactant. diff --git a/ext/ReactantZygoteExt.jl b/ext/ReactantZygoteExt.jl new file mode 100644 index 0000000000..f50229ea17 --- /dev/null +++ b/ext/ReactantZygoteExt.jl @@ -0,0 +1,29 @@ +module ReactantZygoteExt + +using Reactant: + Reactant, CallWithReactant, @reactant_overlay, use_overlayed_version, call_with_reactant +using Zygote: Zygote +using Enzyme: Enzyme, Reverse, Active, Const, Duplicated + +# TODO: overload the following as well +# - Zygote.pullback +# - Zygote.jacobian +# - Zygote.hessian + +@reactant_overlay function Zygote.gradient(f::F, args...) where {F} + # TODO: check `f` as well once #1642 is merged + if Reactant.OVERLAY_ZYGOTE_CALLS[] && use_overlayed_version(args) + @warn "Reactant doesn't support using Zygote for computing gradients. Replacing \ + `Zygote.gradient` with `Enzyme.autodiff` call. Please update your code to \ + not use `Zygote.gradient` and instead use `Enzyme.gradient` inside \ + `Reactant.@compile`. If this behavior is undesirable, set the \ + `overlay_zygote_calls` scoped value via `Reactant.with_config` to \ + `false`.\n\nReactant can remove this switching without any breaking change \ + and hence reliance on this behavior is strongly discouraged." + return Enzyme.gradient(Reverse, Const(f), args...) + else + return Base.inferencebarrier(Zygote.gradient)(CallWithReactant(f), args...) + end +end + +end diff --git a/src/Configuration.jl b/src/Configuration.jl index 5b0eaa00af..35149345e9 100644 --- a/src/Configuration.jl +++ b/src/Configuration.jl @@ -30,6 +30,11 @@ scope will use the provided values. or [`PrecisionConfig`](@ref). Defaults to `PrecisionConfig.DEFAULT`. - `convolution_precision`: Precision for `stablehlo.convolution`. Can be `nothing`, or [`PrecisionConfig`](@ref). Defaults to `PrecisionConfig.DEFAULT`. + +### Zygote Overlay + + - `overlay_zygote_calls`: Whether to overlay `Zygote.gradient` calls with + `Enzyme.autodiff` calls. Defaults to `true`. """ function with_config( f; @@ -38,6 +43,7 @@ function with_config( convolution_precision=missing, lower_partialsort_to_approx_top_k=missing, fallback_approx_top_k_lowering=missing, + overlay_zygote_calls=missing, ) config_vars = () dot_general_algorithm !== missing && @@ -58,6 +64,8 @@ function with_config( FALLBACK_APPROX_TOP_K_LOWERING => fallback_approx_top_k_lowering, ) ) + overlay_zygote_calls !== missing && + (config_vars = (config_vars..., OVERLAY_ZYGOTE_CALLS => overlay_zygote_calls)) return ScopedValues.with(f, config_vars...) end @@ -379,3 +387,6 @@ function DotGeneralAlgorithm( return nothing end + +# Overlay Zygote.jl +const OVERLAY_ZYGOTE_CALLS = ScopedValue(true) diff --git a/test/Project.toml b/test/Project.toml index bb9d5f4cfc..a75f670c27 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -31,6 +31,7 @@ StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] Adapt = "4.1" @@ -65,6 +66,7 @@ StableRNGs = "1" Statistics = "1.10" StatsBase = "0.34" Test = "1.10" +Zygote = "0.7" [extras] Reactant_jll = "0192cb87-2b54-54ad-80e0-3be72ad8a3c0" diff --git a/test/integration/zygote.jl b/test/integration/zygote.jl new file mode 100644 index 0000000000..f442cef391 --- /dev/null +++ b/test/integration/zygote.jl @@ -0,0 +1,22 @@ +using Zygote, Reactant, Enzyme, Test + +sumabs2(x) = sum(abs2, x) + +@testset "Zygote" begin + @testset "Zygote.gradient" begin + x = Reactant.to_rarray(rand(Float32, 32, 10)) + + zyg_grad = @jit Zygote.gradient(sumabs2, x) + enz_grad = @jit Enzyme.gradient(Reverse, Const(sumabs2), x) + @test zyg_grad[1] isa Reactant.ConcreteRArray + @test enz_grad[1] ≈ zyg_grad[1] + + @testset "Disable Overlay" begin + @test_throws Zygote.CompileError Reactant.with_config(; + overlay_zygote_calls=false + ) do + @jit Zygote.gradient(sumabs2, x) + end + end + end +end diff --git a/test/runtests.jl b/test/runtests.jl index d52ebebe90..ec80b802c1 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -51,6 +51,7 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all")) @safetestset "Python" include("integration/python.jl") @safetestset "Optimisers" include("integration/optimisers.jl") @safetestset "FillArrays" include("integration/fillarrays.jl") + @safetestset "Zygote" include("integration/zygote.jl") end if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "neural_networks"