Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ Random123 = "74087812-796a-5b5d-8853-05524746bad3"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
YaoBlocks = "418bc28f-b43b-5e0b-a6e7-61bbc1a2c1df"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[sources]
ReactantCore = {path = "lib/ReactantCore"}
Expand All @@ -67,6 +68,7 @@ ReactantRandom123Ext = "Random123"
ReactantSpecialFunctionsExt = "SpecialFunctions"
ReactantStatisticsExt = "Statistics"
ReactantYaoBlocksExt = "YaoBlocks"
ReactantZygoteExt = "Zygote"

[compat]
AbstractFFTs = "1.5"
Expand Down Expand Up @@ -107,6 +109,7 @@ Sockets = "1.10"
SpecialFunctions = "2.4"
Statistics = "1.10"
YaoBlocks = "0.13, 0.14"
Zygote = "0.7"
julia = "1.10"
unzip_jll = "6"

Expand Down
28 changes: 28 additions & 0 deletions ext/ReactantZygoteExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
module ReactantZygoteExt

using Reactant:
Reactant, CallWithReactant, @reactant_overlay, use_overlayed_version, call_with_reactant
using Zygote: Zygote
using Enzyme: Enzyme, Reverse, Active, Const, Duplicated

# TODO: overload the following as well
# - Zygote.pullback
# - Zygote.jacobian
# - Zygote.hessian

@reactant_overlay function Zygote.gradient(f::F, args...) where {F}
# TODO: check `f` as well once #1642 is merged
if use_overlayed_version(args)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we do this, we should quite aggressively yell that we're gonig to do this -- I would even be okay saying to do this for each call [not even each callsite]

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm also kind of debating if we want to have this behind a feature flag as well [as perhaps it is useful to compare the performance of zygote as a frontend vs us inside the compiler]

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is what I was testing first, but we need to fix our broadcasting quirks before we can work through ChainRules rrules.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah....we should definitely fix that

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will error inside Zygote but we have an option

Reactant.with_config(; overlay_zygote_calls=false) do
    @jit Zygote.gradient(sumabs2, x)
end

@warn "Reactant doesn't support using Zygote for computing gradients. Replacing \
`Zygote.gradient` with `Enzyme.autodiff` call. Please update your code to \
not use `Zygote.gradient` inside `Reactant.@compile`." maxlog = 1
dargs = map(Enzyme.make_zero, args)
duplicated = map(Duplicated, args, dargs)
Reactant.overload_autodiff(Reverse, Const(f), Active, duplicated...)
return dargs
else
return Base.inferencebarrier(Zygote.gradient)(CallWithReactant(f), args...)
end
end

end
2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
Adapt = "4.1"
Expand Down Expand Up @@ -65,6 +66,7 @@ StableRNGs = "1"
Statistics = "1.10"
StatsBase = "0.34"
Test = "1.10"
Zygote = "0.7"

[extras]
Reactant_jll = "0192cb87-2b54-54ad-80e0-3be72ad8a3c0"
14 changes: 14 additions & 0 deletions test/integration/zygote.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
using Zygote, Reactant, Enzyme, Test

sumabs2(x) = sum(abs2, x)

@testset "Zygote" begin
@testset "Zygote.gradient" begin
x = Reactant.to_rarray(rand(Float32, 32, 10))

zyg_grad = @jit Zygote.gradient(sumabs2, x)
enz_grad = @jit Enzyme.gradient(Reverse, Const(sumabs2), x)
@test zyg_grad[1] isa Reactant.ConcreteRArray
@test enz_grad[1] ≈ zyg_grad[1]
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all"))
@safetestset "Python" include("integration/python.jl")
@safetestset "Optimisers" include("integration/optimisers.jl")
@safetestset "FillArrays" include("integration/fillarrays.jl")
@safetestset "Zygote" include("integration/zygote.jl")
end

if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "neural_networks"
Expand Down
Loading