Skip to content

Commit f701e0e

Browse files
Merge pull request #2567 from jClugstor/use_DI
Use DifferentiationInterface for AD in Implicit Solvers
2 parents 57e9bcd + 57ea0bc commit f701e0e

36 files changed

+914
-465
lines changed

Project.toml

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,14 +68,13 @@ SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226"
6868
SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7"
6969
SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a"
7070
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
71-
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
7271
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
7372
StaticArrayInterface = "0d7ed370-da01-4f52-bd93-41d350b8b718"
7473
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
7574
TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77"
7675

7776
[compat]
78-
ADTypes = "0.2, 1"
77+
ADTypes = "1.13"
7978
Adapt = "3.0, 4"
8079
ArrayInterface = "7"
8180
DataStructures = "0.18"
@@ -138,7 +137,6 @@ SciMLOperators = "0.3"
138137
SciMLStructures = "1"
139138
SimpleNonlinearSolve = "1, 2"
140139
SimpleUnPack = "1"
141-
SparseDiffTools = "2"
142140
Static = "0.8, 1"
143141
StaticArrayInterface = "1.2"
144142
StaticArrays = "1.0"
@@ -153,6 +151,7 @@ ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
153151
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
154152
DiffEqDevTools = "f3b72e0c-5b89-59e1-b016-84e28bfd966d"
155153
ElasticArrays = "fdbdab4c-e67f-52f5-8c3f-e7b388dad3d4"
154+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
156155
IncompleteLU = "40713840-3770-5561-ab4c-a76e7d0d7895"
157156
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
158157
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
@@ -167,11 +166,13 @@ RecursiveFactorization = "f2c3362d-daeb-58d1-803e-2bc74f2840b4"
167166
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
168167
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
169168
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
169+
SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
170+
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"
170171
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
171172
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
172173
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
173174
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
174175
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
175176

176177
[targets]
177-
test = ["Calculus", "ComponentArrays", "Symbolics", "AlgebraicMultigrid", "IncompleteLU", "DiffEqCallbacks", "DiffEqDevTools", "ODEProblemLibrary", "ElasticArrays", "InteractiveUtils", "ParameterizedFunctions", "PoissonRandom", "Printf", "Random", "ReverseDiff", "SafeTestsets", "SparseArrays", "Statistics", "StructArrays", "Test", "Unitful", "ModelingToolkit", "Pkg", "NLsolve", "RecursiveFactorization"]
178+
test = ["Calculus", "ComponentArrays", "Symbolics", "AlgebraicMultigrid", "IncompleteLU", "DiffEqCallbacks", "DiffEqDevTools", "ODEProblemLibrary", "ElasticArrays", "InteractiveUtils", "ParameterizedFunctions", "PoissonRandom", "Printf", "Random", "ReverseDiff", "SafeTestsets", "SparseArrays", "Statistics", "StructArrays", "Test", "Unitful", "ModelingToolkit", "Pkg", "NLsolve", "RecursiveFactorization", "Enzyme", "SparseConnectivityTracer", "SparseMatrixColorings"]

lib/OrdinaryDiffEqBDF/Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,12 +50,14 @@ julia = "1.10"
5050

5151
[extras]
5252
DiffEqDevTools = "f3b72e0c-5b89-59e1-b016-84e28bfd966d"
53+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
5354
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
55+
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
5456
ODEProblemLibrary = "fdc4e326-1af4-4b90-96e7-779fcce2daa5"
5557
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
5658
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
5759
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
5860
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
5961

6062
[targets]
61-
test = ["DiffEqDevTools", "ForwardDiff", "Random", "SafeTestsets", "Test", "ODEProblemLibrary", "StaticArrays"]
63+
test = ["DiffEqDevTools", "ForwardDiff", "Random", "SafeTestsets", "Test", "ODEProblemLibrary", "StaticArrays", "Enzyme", "LinearSolve"]

lib/OrdinaryDiffEqBDF/test/bdf_convergence_tests.jl

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# This definitely needs cleaning
2-
using OrdinaryDiffEqBDF, ODEProblemLibrary, DiffEqDevTools
2+
using OrdinaryDiffEqBDF, ODEProblemLibrary, DiffEqDevTools, ADTypes, Enzyme, LinearSolve
33
using OrdinaryDiffEqNonlinearSolve: NLFunctional, NLAnderson, NonlinearSolveAlg
44
using Test, Random
55
Random.seed!(100)
@@ -39,6 +39,27 @@ dts3 = 1 .// 2 .^ (12:-1:7)
3939
@test sim.𝒪est[:l2]1 atol=testTol
4040
@test sim.𝒪est[:l∞]1 atol=testTol
4141

42+
sim = test_convergence(dts, prob, QNDF1(autodiff = AutoFiniteDiff()))
43+
@test sim.𝒪est[:final]1 atol=testTol
44+
@test sim.𝒪est[:l2]1 atol=testTol
45+
@test sim.𝒪est[:l∞]1 atol=testTol
46+
47+
sim = test_convergence(dts,
48+
prob,
49+
QNDF1(autodiff = AutoEnzyme(mode = set_runtime_activity(Enzyme.Forward),
50+
function_annotation = Enzyme.Const)))
51+
@test sim.𝒪est[:final]1 atol=testTol
52+
@test sim.𝒪est[:l2]1 atol=testTol
53+
@test sim.𝒪est[:l∞]1 atol=testTol
54+
55+
sim = test_convergence(dts,
56+
prob,
57+
QNDF1(autodiff = AutoEnzyme(mode = set_runtime_activity(Enzyme.Forward),
58+
function_annotation = Enzyme.Const), linsolve = LinearSolve.KrylovJL()))
59+
@test sim.𝒪est[:final]1 atol=testTol
60+
@test sim.𝒪est[:l2]1 atol=testTol
61+
@test sim.𝒪est[:l∞]1 atol=testTol
62+
4263
sim = test_convergence(dts3, prob, QNDF2())
4364
@test sim.𝒪est[:final]2 atol=testTol
4465
@test sim.𝒪est[:l2]2 atol=testTol

lib/OrdinaryDiffEqCore/src/OrdinaryDiffEqCore.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ import DiffEqBase: calculate_residuals,
7272

7373
import Polyester
7474
using MacroTools, Adapt
75-
import ADTypes: AutoFiniteDiff, AutoForwardDiff, AbstractADType
75+
import ADTypes: AutoFiniteDiff, AutoForwardDiff, AbstractADType, AutoSparse
7676
import Accessors: @reset
7777

7878
using SciMLStructures: canonicalize, Tunable, isscimlstructure

lib/OrdinaryDiffEqCore/src/alg_utils.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,7 @@ _get_fwd_chunksize(::Type{<:AutoForwardDiff{CS}}) where {CS} = Val(CS)
176176
_get_fwd_chunksize_int(::Type{<:AutoForwardDiff{CS}}) where {CS} = CS
177177
_get_fwd_chunksize(AD) = Val(0)
178178
_get_fwd_chunksize_int(AD) = 0
179+
_get_fwd_chunksize_int(::AutoForwardDiff{CS}) where {CS} = CS
179180
_get_fwd_tag(::AutoForwardDiff{CS, T}) where {CS, T} = T
180181

181182
_get_fdtype(::AutoFiniteDiff{T1}) where {T1} = T1

lib/OrdinaryDiffEqCore/src/algorithms.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,16 @@ function DiffEqBase.remake(
5959
},
6060
DAEAlgorithm{CS, AD, FDT, ST, CJ}};
6161
kwargs...) where {CS, AD, FDT, ST, CJ}
62+
63+
if haskey(kwargs, :autodiff) && kwargs[:autodiff] isa AutoForwardDiff
64+
chunk_size = _get_fwd_chunksize(kwargs[:autodiff])
65+
else
66+
chunk_size = Val{CS}()
67+
end
68+
6269
T = SciMLBase.remaker_of(thing)
6370
T(; SciMLBase.struct_as_namedtuple(thing)...,
64-
chunk_size = Val{CS}(), autodiff = thing.autodiff, standardtag = Val{ST}(),
71+
chunk_size = chunk_size, autodiff = thing.autodiff, standardtag = Val{ST}(),
6572
concrete_jac = CJ === nothing ? CJ : Val{CJ}(),
6673
kwargs...)
6774
end

lib/OrdinaryDiffEqCore/src/misc_utils.jl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,10 +158,11 @@ end
158158
function _process_AD_choice(
159159
ad_alg::AutoForwardDiff{CS}, ::Val{CS2}, ::Val{FD}) where {CS, CS2, FD}
160160
# Non-default `chunk_size`
161-
if CS2 != 0
161+
if (CS2 != 0) && (isnothing(CS) || (CS2 !== CS))
162162
@warn "The `chunk_size` keyword is deprecated. Please use an `ADType` specifier. For now defaulting to using `AutoForwardDiff` with `chunksize=$(CS2)`."
163163
return _bool_to_ADType(Val{true}(), Val{CS2}(), Val{FD}()), Val{CS2}(), Val{FD}()
164164
end
165+
165166
_CS = CS === nothing ? 0 : CS
166167
return ad_alg, Val{_CS}(), Val{FD}()
167168
end
@@ -186,3 +187,12 @@ function _process_AD_choice(
186187
end
187188
return ad_alg, Val{CS}(), ad_alg.fdtype
188189
end
190+
191+
function _process_AD_choice(ad_alg::AutoSparse, cs2::Val{CS2}, fd::Val{FD}) where {CS2, FD}
192+
_, cs, fd = _process_AD_choice(ad_alg.dense_ad, cs2, fd)
193+
ad_alg, cs, fd
194+
end
195+
196+
function _process_AD_choice(ad_alg, cs2, fd)
197+
ad_alg, cs2, fd
198+
end

lib/OrdinaryDiffEqDifferentiation/Project.toml

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@ version = "1.4.0"
66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
88
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
9+
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
10+
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
911
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
12+
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
1013
FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
1114
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
1215
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
@@ -15,16 +18,18 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1518
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
1619
OrdinaryDiffEqCore = "bbf590c4-e513-4bbe-9b18-05decba2e5d8"
1720
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
21+
SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
1822
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
19-
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
23+
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"
2024
StaticArrayInterface = "0d7ed370-da01-4f52-bd93-41d350b8b718"
2125
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2226

2327
[compat]
24-
ADTypes = "1.11"
28+
ADTypes = "1.14"
2529
ArrayInterface = "7"
2630
DiffEqBase = "6"
2731
DiffEqDevTools = "2.44.4"
32+
DifferentiationInterface = "0.6.48"
2833
FastBroadcast = "0.3"
2934
FiniteDiff = "2"
3035
ForwardDiff = "0.10"
@@ -36,7 +41,6 @@ Random = "<0.0.1, 1"
3641
SafeTestsets = "0.1.0"
3742
SciMLBase = "2"
3843
SparseArrays = "1"
39-
SparseDiffTools = "2"
4044
StaticArrayInterface = "1"
4145
StaticArrays = "1"
4246
Test = "<0.0.1, 1"

lib/OrdinaryDiffEqDifferentiation/src/OrdinaryDiffEqDifferentiation.jl

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,7 @@
11
module OrdinaryDiffEqDifferentiation
22

3-
import ADTypes: AutoFiniteDiff, AutoForwardDiff, AbstractADType
4-
5-
import SparseDiffTools: SparseDiffTools, matrix_colors, forwarddiff_color_jacobian!,
6-
forwarddiff_color_jacobian, ForwardColorJacCache,
7-
default_chunk_size, getsize, JacVec
3+
import ADTypes
4+
import ADTypes: AutoFiniteDiff, AutoForwardDiff, AbstractADType, AutoSparse
85

96
import ForwardDiff, FiniteDiff
107
import ForwardDiff.Dual
@@ -16,7 +13,7 @@ using DiffEqBase
1613
import LinearAlgebra
1714
import LinearAlgebra: Diagonal, I, UniformScaling, diagind, mul!, lmul!, axpby!, opnorm, lu
1815
import LinearAlgebra: LowerTriangular, UpperTriangular
19-
import SparseArrays: SparseMatrixCSC, AbstractSparseMatrix, nonzeros
16+
import SparseArrays: SparseMatrixCSC, AbstractSparseMatrix, nonzeros, sparse
2017
import ArrayInterface
2118

2219
import StaticArrayInterface
@@ -27,7 +24,9 @@ import StaticArrays: SArray, MVector, SVector, @SVector, StaticArray, MMatrix, S
2724
using DiffEqBase: TimeGradientWrapper,
2825
UJacobianWrapper, TimeDerivativeWrapper,
2926
UDerivativeWrapper
30-
using SciMLBase: AbstractSciMLOperator, constructorof
27+
using SciMLBase: AbstractSciMLOperator, constructorof, @set
28+
using SciMLOperators
29+
import SparseMatrixColorings
3130
import OrdinaryDiffEqCore
3231
using OrdinaryDiffEqCore: OrdinaryDiffEqAlgorithm, OrdinaryDiffEqAdaptiveImplicitAlgorithm,
3332
DAEAlgorithm,
@@ -44,11 +43,16 @@ using OrdinaryDiffEqCore: OrdinaryDiffEqAlgorithm, OrdinaryDiffEqAdaptiveImplici
4443
FastConvergence, Convergence, SlowConvergence,
4544
VerySlowConvergence, Divergence, NLStatus, MethodType, constvalue
4645

47-
import OrdinaryDiffEqCore: get_chunksize, resize_J_W!, resize_nlsolver!, alg_autodiff,
48-
_get_fwd_tag
46+
import OrdinaryDiffEqCore: get_chunksize, resize_J_W!, resize_nlsolver!, alg_autodiff, _get_fwd_tag
47+
48+
using ConstructionBase
49+
50+
import DifferentiationInterface as DI
4951

5052
using FastBroadcast: @..
5153

54+
using ConcreteStructs: @concrete
55+
5256
@static if isdefined(DiffEqBase, :OrdinaryDiffEqTag)
5357
import DiffEqBase: OrdinaryDiffEqTag
5458
else
@@ -59,5 +63,6 @@ include("alg_utils.jl")
5963
include("linsolve_utils.jl")
6064
include("derivative_utils.jl")
6165
include("derivative_wrappers.jl")
66+
include("operators.jl")
6267

6368
end

0 commit comments

Comments
 (0)