diff --git a/Project.toml b/Project.toml index ab0bf01..1cb6b2c 100644 --- a/Project.toml +++ b/Project.toml @@ -9,11 +9,13 @@ ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" [weakdeps] +EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5" [extensions] +PreallocationToolsEnzymeCoreExt = "EnzymeCore" PreallocationToolsForwardDiffExt = "ForwardDiff" PreallocationToolsReverseDiffExt = "ReverseDiff" PreallocationToolsSparseConnectivityTracerExt = "SparseConnectivityTracer" @@ -23,6 +25,8 @@ ADTypes = "1.16" Adapt = "4.3.0" Aqua = "0.8.11" ArrayInterface = "7.19.0" +Enzyme = "0.13" +EnzymeCore = "0.8" ForwardDiff = "0.10.38, 1.0.1" LabelledArrays = "1.16.0" LinearAlgebra = "1.10" @@ -44,6 +48,8 @@ julia = "1.10" [extras] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -61,4 +67,4 @@ Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Aqua", "ADTypes", "ForwardDiff", "Random", "LabelledArrays", "LinearAlgebra", "OrdinaryDiffEq", "Test", "RecursiveArrayTools", "Pkg", "SafeTestsets", "Optimization", "OptimizationOptimJL", "SparseArrays", "Symbolics", "SparseConnectivityTracer"] +test = ["Aqua", "ADTypes", "Enzyme", "ForwardDiff", "Random", "LabelledArrays", "LinearAlgebra", "OrdinaryDiffEq", "Test", "RecursiveArrayTools", "Pkg", "SafeTestsets", "Optimization", "OptimizationOptimJL", "SparseArrays", "Symbolics", "SparseConnectivityTracer"] diff --git a/ext/PreallocationToolsEnzymeCoreExt.jl b/ext/PreallocationToolsEnzymeCoreExt.jl new file mode 100644 index 0000000..1310913 --- /dev/null +++ b/ext/PreallocationToolsEnzymeCoreExt.jl @@ -0,0 +1,51 @@ +module PreallocationToolsEnzymeCoreExt + +using PreallocationTools +import EnzymeCore: EnzymeRules, Const, Duplicated + +# TODO: Support Batched mode, on 1.11 +# if VERSION >= v"1.11.0" +# function tuple_of_vectors(M::Matrix{T}, shape) where {T} +# n, m = size(M) +# return ntuple(m) do i +# vec = Base.wrap(Array, memoryref(M.ref, (i - 1) * n + 1), (n,)) +# reshape(vec, shape) +# end +# end +# end + +# TODO: Support reverse mode? + +function EnzymeRules.forward(config, func::Const{typeof(PreallocationTools.get_tmp)}, ::Type{<:Duplicated}, + dc::Duplicated{<:PreallocationTools.DiffCache}, u::Union{Const{T}, Duplicated{T}}) where {T} + du = PreallocationTools.get_tmp(dc.val, u.val) + ddu = PreallocationTools.get_tmp(dc.dval, u.val) + Duplicated(du, ddu) +end + +function EnzymeRules.forward(config, func::Const{typeof(PreallocationTools.get_tmp)}, ::Type{<:Duplicated}, + dc::Const{<:PreallocationTools.DiffCache}, u::Union{Const{T}, Duplicated{T}}) where {T} + dc = dc.val + du = PreallocationTools.get_tmp(dc, u.val) + + # ddu = if isbitstype(T) + # nelem = div(sizeof(T), sizeof(eltype(dc.dual_du))) * length(dc.du) + # if nelem > length(dc.dual_du) + # PreallocationTools.enlargediffcache!(dc, nelem) + # end + # PreallocationTools._restructure(dc.du, reinterpret(T, view(dc.dual_du, 1:nelem))) + # else + # PreallocationTools._restructure(dc.du, zeros(T, size(dc.du))) + # end + + # Enzyme requires that Duplicated types have the same type and structure + # the above code fails since it creates something like a `Base.ReshapedArray{Float64, 2, SubArray{…}, Tuple{}})` + + # TODO: How does this interact with Enzyme over ForwardDiff? + ddu = dc.dual_du + resize!(ddu, length(du)) + + Duplicated(du, reshape(ddu, size(du))) +end + +end diff --git a/test/enzyme.jl b/test/enzyme.jl new file mode 100644 index 0000000..1f2ceff --- /dev/null +++ b/test/enzyme.jl @@ -0,0 +1,42 @@ +module TestEnzyme + using Enzyme + using PreallocationTools + using ForwardDiff + + const randmat = rand(5, 3) + + + function claytonsample!(sto, τ, α; randmat = randmat) + sto = get_tmp(sto, τ) + sto .= randmat + τ == 0 && return sto + + n = size(sto, 1) + for i in 1:n + v = sto[i, 2] + u = sto[i, 1] + sto[i, 1] = (1 - u^(-τ) + u^(-τ) * v^(-(τ / (1 + τ))))^(-1 / τ) * α + sto[i, 2] = (1 - u^(-τ) + u^(-τ) * v^(-(τ / (1 + τ))))^(-1 / τ) + end + return sto + end + + sto = similar(randmat) + stod = DiffCache(sto) + + d_sto_fwd = ForwardDiff.derivative(τ -> claytonsample!(stod, τ, 0.0), 0.3) + d_sto_enz = Enzyme.autodiff(Forward, claytonsample!, Const(stod), Duplicated(0.3, 1.0), Const(0.0)) |> only + + @test d_sto_enz ≈ d_sto_fwd + + d_sto_enz2 = Enzyme.autodiff(Forward, claytonsample!, Duplicated(stod, Enzyme.make_zero(stod)), Duplicated(0.3, 1.0), Const(0.0)) |> only + @test d_sto_enz2 ≈ d_sto_fwd + + d_sto_enz3 = Enzyme.autodiff(Forward, claytonsample!, Const(stod), Const(0.3), Const(0.0)) |> only + @test all(d_sto_enz3 .== 0.0) + + d_sto_enz4 = Enzyme.autodiff(Forward, claytonsample!, Const(stod), Const(0.3), Duplicated(1.0, 1.0)) |> only + d_sto_fwd4 = reshape(ForwardDiff.jacobian(x -> claytonsample!(stod, x[1], x[2]), [0.3; 0.0])[:, 2], size(sto)) + @test d_sto_enz4 ≈ d_sto_fwd4 +end # TestEnzyme + diff --git a/test/runtests.jl b/test/runtests.jl index c6e5e99..632ac69 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -17,6 +17,7 @@ if GROUP == "All" || GROUP == "Core" @safetestset "DiffCache Nested Duals" include("core_nesteddual.jl") @safetestset "DiffCache Sparsity Support" include("sparsity_support.jl") @safetestset "DiffCache with SparseConnectivityTracer" include("sparse_connectivity_tracer.jl") + @safetestset "DiffCache with Enzyme" include("enzyme.jl") @safetestset "LazyBufferCache" include("lbc.jl") @safetestset "GeneralLazyBufferCache" include("general_lbc.jl") @safetestset "Zero and Copy Dispatches" include("test_zero_copy.jl")