Skip to content

Commit e1ee0a9

Browse files
Merge pull request #291 from ErikQQY/qqy/more_AD
Support Enzyme and Mooncake in AD backends
2 parents 61fb9ed + cf39fc8 commit e1ee0a9

File tree

23 files changed

+1257
-170
lines changed

23 files changed

+1257
-170
lines changed

docs/src/basics/autodiff.md

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
!!! note
44

5-
We support ForwardDiff.jl, FiniteDiff.jl and PolyesterForwardDiff.jl(PolyesterForwardDiff only for collocation methods) via DifferentiationInterface.jl. Please refer to
5+
We support ForwardDiff.jl, FiniteDiff.jl, Enzyme.jl, Mooncake.jl and PolyesterForwardDiff.jl (PolyesterForwardDiff only for collocation methods) via DifferentiationInterface.jl. Please refer to
66
the [backends page](https://juliadiff.org/DifferentiationInterface.jl/DifferentiationInterface/stable/explanation/backends/)
77
for more information.
88

@@ -25,3 +25,9 @@ In BoundaryValueDiffEq.jl, we require AD to obtain the Jacobian of the loss func
2525
- [`AutoPolyesterForwardDiff`](@extref ADTypes.AutoPolyesterForwardDiff): Might be faster
2626
than [`AutoForwardDiff`](@extref ADTypes.AutoForwardDiff) for large problems. Requires
2727
`PolyesterForwardDiff.jl` to be installed and loaded.
28+
- [`AutoEnzyme(; mode = Enzyme.Forward)`](@extref ADTypes.AutoEnzyme): Source transformation forward mode AD.
29+
30+
## Summary of Reverse Mode AD Backends
31+
32+
- [`AutoEnzyme(; mode = Enzyme.Reverse)`](@extref ADTypes.AutoEnzyme): Source transformation reverse mode AD.
33+
- [`AutoMooncake(; config = nothing)`](@extref ADTypes.AutoMooncake): Source transformation reverse mode AD.

docs/src/devdocs/internal_interfaces.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,6 @@
22

33
## Solvers
44

5-
```julia
5+
```@docs
66
BoundaryValueDiffEqCore.AbstractBoundaryValueDiffEqAlgorithm
77
```

lib/BoundaryValueDiffEqAscher/src/BoundaryValueDiffEqAscher.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,15 @@ module BoundaryValueDiffEqAscher
22

33
using ADTypes: ADTypes, AutoSparse, AutoForwardDiff
44
using AlmostBlockDiagonals: AlmostBlockDiagonals, IntermediateAlmostBlockDiagonal
5+
56
using BoundaryValueDiffEqCore: AbstractBoundaryValueDiffEqAlgorithm,
67
AbstractBoundaryValueDiffEqCache, BVPJacobianAlgorithm,
78
__extract_problem_details, concrete_jacobian_algorithm,
89
__Fix3, __concrete_nonlinearsolve_algorithm,
910
__internal_nlsolve_problem, __vec, __vec_f, __vec_f!,
1011
__vec_bc, __vec_bc!, __extract_mesh, get_dense_ad,
11-
__get_bcresid_prototype, __split_kwargs
12+
__get_bcresid_prototype, __split_kwargs,
13+
__default_nonsparse_ad
1214

1315
using ConcreteStructs: @concrete
1416
using DiffEqBase: DiffEqBase

lib/BoundaryValueDiffEqAscher/src/algorithms.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ for stage in (1, 2, 3, 4, 5, 6, 7)
5555
end
5656
end
5757

58-
function concretize_jacobian_algorithm(alg::AbstractAscher, prob)
59-
@set! alg.jac_alg = concrete_jacobian_algorithm(alg.jac_alg, prob, alg)
60-
return alg
58+
function BoundaryValueDiffEqCore.concrete_jacobian_algorithm(
59+
jac_alg::BVPJacobianAlgorithm, prob::BVProblem, alg::AbstractAscher)
60+
return BVPJacobianAlgorithm(__default_nonsparse_ad(prob.u0))
6161
end

lib/BoundaryValueDiffEqCore/src/BoundaryValueDiffEqCore.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ using SciMLBase: SciMLBase, AbstractBVProblem, AbstractDiffEqInterpolation,
1818
StandardBVProblem, StandardSecondOrderBVProblem, __solve, _unwrap_val
1919
using Setfield: @set!, @set
2020
using SparseArrays: sparse
21-
using SparseConnectivityTracer: TracerLocalSparsityDetector
21+
using SparseConnectivityTracer: SparseConnectivityTracer, TracerLocalSparsityDetector
2222
using SparseMatrixColorings: GreedyColoringAlgorithm
2323

2424
@reexport using NonlinearSolveFirstOrder, SciMLBase

lib/BoundaryValueDiffEqCore/src/types.jl

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -68,15 +68,25 @@ function concrete_jacobian_algorithm(
6868
return concrete_jacobian_algorithm(jac_alg, prob.problem_type, prob, alg)
6969
end
7070

71+
# For multi-point BVP, we only care about bc_diffmode and nonbc_diffmode
7172
function concrete_jacobian_algorithm(
72-
jac_alg::BVPJacobianAlgorithm, prob_type, prob::BVProblem, alg)
73+
jac_alg::BVPJacobianAlgorithm, prob_type::StandardBVProblem, prob::BVProblem, alg)
7374
u0 = __extract_u0(prob.u0, prob.p, first(prob.tspan))
74-
diffmode = jac_alg.diffmode === nothing ? __default_sparse_ad(u0) : jac_alg.diffmode
75-
bc_diffmode = jac_alg.bc_diffmode === nothing ?
76-
(prob_type isa TwoPointBVProblem ? __default_bc_sparse_ad :
77-
__default_nonsparse_ad)(u0) : jac_alg.bc_diffmode
75+
bc_diffmode = jac_alg.bc_diffmode === nothing ? __default_bc_sparse_ad(u0) :
76+
jac_alg.bc_diffmode
7877
nonbc_diffmode = jac_alg.nonbc_diffmode === nothing ? __default_sparse_ad(u0) :
7978
jac_alg.nonbc_diffmode
79+
diffmode = jac_alg.diffmode === nothing ? nothing : jac_alg.diffmode
80+
return BVPJacobianAlgorithm(bc_diffmode, nonbc_diffmode, diffmode)
81+
end
82+
83+
# For two-point BVP, we only care about diffmode
84+
function concrete_jacobian_algorithm(
85+
jac_alg::BVPJacobianAlgorithm, prob_type::TwoPointBVProblem, prob::BVProblem, alg)
86+
u0 = __extract_u0(prob.u0, prob.p, first(prob.tspan))
87+
diffmode = jac_alg.diffmode === nothing ? __default_sparse_ad(u0) : jac_alg.diffmode
88+
bc_diffmode = jac_alg.bc_diffmode === nothing ? nothing : jac_alg.bc_diffmode
89+
nonbc_diffmode = jac_alg.nonbc_diffmode === nothing ? nothing : jac_alg.nonbc_diffmode
8090
return BVPJacobianAlgorithm(bc_diffmode, nonbc_diffmode, diffmode)
8191
end
8292

@@ -151,22 +161,13 @@ end
151161
__needs_diffcache(jac_alg.nonbc_diffmode)
152162
end
153163

154-
# We don't need to always allocate a DiffCache. This works around that.
155-
@concrete struct FakeDiffCache
156-
du
157-
end
158-
159164
function __maybe_allocate_diffcache(x, chunksize, jac_alg)
160-
return __needs_diffcache(jac_alg) ? DiffCache(x, chunksize) : FakeDiffCache(x)
165+
return __needs_diffcache(jac_alg) ? DiffCache(x, chunksize) : x
161166
end
162167
__maybe_allocate_diffcache(x::DiffCache, chunksize) = DiffCache(zero(x.du), chunksize)
163-
__maybe_allocate_diffcache(x::FakeDiffCache, _) = FakeDiffCache(zero(x.du))
164-
165-
const MaybeDiffCache = Union{DiffCache, FakeDiffCache}
166168

167169
## get_tmp shows a warning, as it should, on cache expansion; this behavior, however, is
168170
## expected for adaptive BVP solvers so we write our own `get_tmp` and drop the warning logs
169-
@inline get_tmp(dc::FakeDiffCache, u) = dc.du
170171

171172
@inline function get_tmp(dc, u)
172173
return Logging.with_logger(Logging.NullLogger()) do
@@ -180,4 +181,8 @@ struct NoDiffCacheNeeded end
180181

181182
@inline __cache_trait(::AutoForwardDiff) = DiffCacheNeeded()
182183
@inline __cache_trait(ad::AutoSparse) = __cache_trait(ADTypes.dense_ad(ad))
184+
@inline function __cache_trait(jac_alg::BVPJacobianAlgorithm)
185+
isnothing(jac_alg.diffmode) ? __cache_trait(jac_alg.nonbc_diffmode) :
186+
__cache_trait(jac_alg.diffmode)
187+
end
183188
@inline __cache_trait(_) = NoDiffCacheNeeded()

lib/BoundaryValueDiffEqCore/src/utils.jl

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
recursive_length(x::Vector{<:AbstractArray}) = sum(length, x)
2-
recursive_length(x::Vector{<:MaybeDiffCache}) = sum(xᵢ -> length(xᵢ.du), x)
2+
recursive_length(x::Vector{<:DiffCache}) = sum(xᵢ -> length(xᵢ.u), x)
33

44
function recursive_flatten(x::Vector{<:AbstractArray})
55
y = zero(first(x), recursive_length(x))
@@ -37,7 +37,7 @@ end
3737
return y
3838
end
3939

40-
@views function recursive_unflatten!(y::Vector{<:MaybeDiffCache}, x::AbstractVector)
40+
@views function recursive_unflatten!(y::Vector{<:DiffCache}, x::AbstractVector)
4141
return recursive_unflatten!(get_tmp.(y, (x,)), x)
4242
end
4343

@@ -230,7 +230,7 @@ end
230230
__resize!(::Nothing, n, _) = nothing
231231
__resize!(::Nothing, n, _, _) = nothing
232232

233-
function __resize!(x::AbstractVector{<:MaybeDiffCache}, n, M)
233+
function __resize!(x::AbstractVector{<:DiffCache}, n, M)
234234
N = n - length(x)
235235
N == 0 && return x
236236
if N > 0
@@ -567,6 +567,14 @@ function _sparse_like(I, J, x::AbstractArray, m = maximum(I), n = maximum(J))
567567
return sparse(I′, J′, V, m, n)
568568
end
569569

570+
nodual_value(x) = x
571+
nodual_value(x::ForwardDiff.Dual) = ForwardDiff.value(x)
572+
nodual_value(x::AbstractArray{<:ForwardDiff.Dual}) = map(ForwardDiff.value, x)
573+
nodual_value(x::SparseConnectivityTracer.Dual) = SparseConnectivityTracer.primal(x)
574+
function nodual_value(x::AbstractArray{<:SparseConnectivityTracer.Dual})
575+
map(SparseConnectivityTracer.primal, x)
576+
end
577+
570578
function __split_kwargs(; abstol, adaptive, controller, kwargs...)
571579
return ((abstol, adaptive, controller), (; abstol, adaptive, kwargs...))
572580
end

lib/BoundaryValueDiffEqFIRK/Project.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ ConcreteStructs = "0.2.3"
3636
DiffEqBase = "6.167"
3737
DiffEqDevTools = "2.44"
3838
DifferentiationInterface = "0.6.42"
39+
Enzyme = "0.13.33"
3940
FastAlmostBandedMatrices = "0.1.4"
4041
FastClosures = "0.3.2"
4142
ForwardDiff = "0.10.38, 1"
@@ -44,6 +45,7 @@ InteractiveUtils = "<0.0.1, 1"
4445
JET = "0.9.18"
4546
LinearAlgebra = "1.10"
4647
LinearSolve = "2.36.2, 3"
48+
Mooncake = "0.4.108"
4749
OrdinaryDiffEqRosenbrock = "1"
4850
PreallocationTools = "0.4.24"
4951
PrecompileTools = "1.2"
@@ -62,10 +64,12 @@ julia = "1.10"
6264
[extras]
6365
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
6466
DiffEqDevTools = "f3b72e0c-5b89-59e1-b016-84e28bfd966d"
67+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
6568
Hwloc = "0e44f5e4-bd66-52a0-8798-143a42290a1d"
6669
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
6770
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
6871
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
72+
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
6973
OrdinaryDiffEqRosenbrock = "43230ef6-c299-4910-a778-202eb28ce4ce"
7074
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
7175
ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
@@ -74,4 +78,4 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
7478
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
7579

7680
[targets]
77-
test = ["Aqua", "DiffEqDevTools", "Hwloc", "InteractiveUtils", "JET", "LinearSolve", "OrdinaryDiffEqRosenbrock", "Random", "ReTestItems", "RecursiveArrayTools", "StaticArrays", "Test"]
81+
test = ["Aqua", "DiffEqDevTools", "Enzyme", "Hwloc", "InteractiveUtils", "JET", "LinearSolve", "OrdinaryDiffEqRosenbrock", "Mooncake", "Random", "ReTestItems", "RecursiveArrayTools", "StaticArrays", "Test"]

lib/BoundaryValueDiffEqFIRK/src/BoundaryValueDiffEqFIRK.jl

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,20 @@ using BoundaryValueDiffEqCore: AbstractBoundaryValueDiffEqAlgorithm,
99
__concrete_nonlinearsolve_algorithm, diff!, EvalSol,
1010
concrete_jacobian_algorithm, eval_bc_residual, interval,
1111
eval_bc_residual!, get_tmp, __maybe_matmul!, __resize!,
12-
__extract_problem_details, __initial_guess,
12+
__extract_problem_details, __initial_guess, nodual_value,
13+
__maybe_allocate_diffcache, __restructure_sol,
14+
__get_bcresid_prototype, __vec, __vec_f, __vec_f!, __vec_bc,
15+
__vec_bc!, recursive_flatten_twopoint!,
16+
__internal_nlsolve_problem, __extract_mesh, __extract_u0,
1317
__default_coloring_algorithm, __maybe_allocate_diffcache,
1418
__restructure_sol, __get_bcresid_prototype, safe_similar,
15-
__vec, __vec_f, __vec_f!, __vec_bc, __vec_bc!,
19+
__vec, __vec_f, __vec_f!, __vec_bc, __vec_bc!, __cache_trait,
1620
recursive_flatten_twopoint!, __internal_nlsolve_problem,
17-
MaybeDiffCache, __extract_mesh, __extract_u0,
18-
__has_initial_guess, __initial_guess_length,
19-
__initial_guess_on_mesh, __flatten_initial_guess,
20-
__split_kwargs, __build_solution, __Fix3, _sparse_like,
21-
get_dense_ad
21+
__extract_mesh, __extract_u0, DiffCacheNeeded,
22+
NoDiffCacheNeeded, __has_initial_guess,
23+
__initial_guess_length, __initial_guess_on_mesh,
24+
__flatten_initial_guess, __build_solution, __Fix3,
25+
__split_kwargs, _sparse_like, get_dense_ad
2226

2327
using ConcreteStructs: @concrete
2428
using DiffEqBase: DiffEqBase

0 commit comments

Comments
 (0)