Skip to content

Commit b99ea13

Browse files
Merge pull request #283 from ErikQQY/qqy/di
Use DifferentiationInterface for Jacobian stuff
2 parents 7c3246c + 5f1d9ad commit b99ea13

32 files changed

+535
-517
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "BoundaryValueDiffEq"
22
uuid = "764a87c0-6b3e-53db-9096-fe964310641d"
3-
version = "5.15.0"
3+
version = "5.16.0"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -25,7 +25,7 @@ ODEInterface = "54ca160b-1b9f-5127-a996-1867f4bc2a2c"
2525
BoundaryValueDiffEqODEInterfaceExt = "ODEInterface"
2626

2727
[compat]
28-
ADTypes = "1.11"
28+
ADTypes = "1.13"
2929
Aqua = "0.8.9"
3030
ArrayInterface = "7.18"
3131
BoundaryValueDiffEqAscher = "1"

docs/src/basics/autodiff.md

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
!!! note
44

5-
We support all backends supported by DifferentiationInterface.jl. Please refer to
5+
We support ForwardDiff.jl, FiniteDiff.jl and PolyesterForwardDiff.jl(PolyesterForwardDiff only for collocation methods) via DifferentiationInterface.jl. Please refer to
66
the [backends page](https://juliadiff.org/DifferentiationInterface.jl/DifferentiationInterface/stable/explanation/backends/)
77
for more information.
88

@@ -17,8 +17,6 @@ In BoundaryValueDiffEq.jl, we require AD to obtain the Jacobian of the loss func
1717

1818
- [`AutoFiniteDiff`](@extref ADTypes.AutoFiniteDiff): Finite differencing using
1919
`FiniteDiff.jl`, not optimal but always applicable.
20-
- [`AutoFiniteDifferences`](@extref ADTypes.AutoFiniteDifferences): Finite differencing
21-
using `FiniteDifferences.jl`, not optimal but always applicable.
2220

2321
## Summary of Forward Mode AD Backends
2422

@@ -27,10 +25,3 @@ In BoundaryValueDiffEq.jl, we require AD to obtain the Jacobian of the loss func
2725
- [`AutoPolyesterForwardDiff`](@extref ADTypes.AutoPolyesterForwardDiff): Might be faster
2826
than [`AutoForwardDiff`](@extref ADTypes.AutoForwardDiff) for large problems. Requires
2927
`PolyesterForwardDiff.jl` to be installed and loaded.
30-
31-
## Summary of Reverse Mode AD Backends
32-
33-
- [`AutoZygote`](@extref ADTypes.AutoZygote): The fastest choice for non-mutating
34-
array-based (BLAS) functions.
35-
- [`AutoEnzyme`](@extref ADTypes.AutoEnzyme): Uses `Enzyme.jl` Reverse Mode and works for
36-
both in-place and out-of-place functions.

lib/BoundaryValueDiffEqAscher/Project.toml

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,21 @@
11
name = "BoundaryValueDiffEqAscher"
22
uuid = "7227322d-7511-4e07-9247-ad6ff830280e"
33
authors = ["Qingyu Qu <[email protected]>"]
4-
version = "1.4.0"
4+
version = "1.5.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
8-
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
98
AlmostBlockDiagonals = "a95523ee-d6da-40b5-98cc-27bc505739d5"
109
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
1110
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
1211
BoundaryValueDiffEqCore = "56b672f2-a5fe-4263-ab2d-da677488eb3a"
1312
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
1413
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
14+
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
1515
FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
1616
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1717
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1818
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
19-
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
2019
PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46"
2120
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
2221
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
@@ -25,26 +24,24 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
2524
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
2625
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
2726
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
28-
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
2927

3028
[compat]
3129
ADTypes = "1.11"
32-
Adapt = "4.1.1"
3330
AlmostBlockDiagonals = "0.1.10"
3431
ArrayInterface = "7.18"
3532
BandedMatrices = "1.7.5"
3633
BoundaryValueDiffEqCore = "1"
3734
ConcreteStructs = "0.2.3"
3835
DiffEqBase = "6.158.3"
3936
DiffEqDevTools = "2.44"
37+
DifferentiationInterface = "0.6.42"
4038
FastClosures = "0.3.2"
4139
ForwardDiff = "0.10.38"
4240
Hwloc = "3"
4341
InteractiveUtils = "<0.0.1, 1"
4442
JET = "0.9.12"
4543
LinearAlgebra = "1.10"
4644
LinearSolve = "2.36.2, 3"
47-
Logging = "1.10"
4845
PreallocationTools = "0.4.24"
4946
PrecompileTools = "1.2"
5047
Preferences = "1.4"
@@ -55,7 +52,6 @@ Reexport = "1.2"
5552
SciMLBase = "2.71"
5653
Setfield = "1.1.1"
5754
SparseArrays = "1.10"
58-
SparseDiffTools = "2.23"
5955
StaticArrays = "1.9.8"
6056
Test = "1.10"
6157
julia = "1.10"

lib/BoundaryValueDiffEqAscher/src/BoundaryValueDiffEqAscher.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@ using BoundaryValueDiffEqCore: BVPJacobianAlgorithm, __extract_problem_details,
77
__concrete_nonlinearsolve_algorithm,
88
__internal_nlsolve_problem, BoundaryValueDiffEqAlgorithm,
99
__vec, __vec_f, __vec_f!, __vec_bc, __vec_bc!,
10-
__extract_mesh, get_dense_ad, __sparse_jacobian_cache
10+
__extract_mesh, get_dense_ad, __get_bcresid_prototype
1111
using ConcreteStructs: @concrete
1212
using DiffEqBase: DiffEqBase
13+
using DifferentiationInterface: DifferentiationInterface, Constant, prepare_jacobian
1314
using FastClosures: @closure
1415
using ForwardDiff: ForwardDiff, Dual
1516
using LinearAlgebra
@@ -19,8 +20,8 @@ using Reexport: @reexport
1920
using SciMLBase: SciMLBase, AbstractDiffEqInterpolation, StandardBVProblem, __solve,
2021
_unwrap_val
2122
using Setfield: @set!
22-
using SparseDiffTools: init_jacobian, sparse_jacobian, sparse_jacobian_cache,
23-
sparse_jacobian!, SymbolicsSparsityDetection, NoSparsityDetection
23+
24+
const DI = DifferentiationInterface
2425

2526
@reexport using ADTypes, BoundaryValueDiffEqCore, SciMLBase
2627

lib/BoundaryValueDiffEqAscher/src/ascher.jl

Lines changed: 37 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ function SciMLBase.__init(prob::BVProblem, alg::AbstractAscher; dt = 0.0,
136136
end
137137

138138
if prob.f.bcjac === nothing
139-
bcjac = construct_bc_jac(prob, bcresid_prototype, prob.problem_type)
139+
bcjac = construct_bc_jac(prob)
140140
else
141141
bcjac = prob.f.bcjac
142142
end
@@ -177,7 +177,7 @@ end
177177
function __perform_ascher_iteration(cache::AscherCache{iip, T}, abstol, adaptive::Bool;
178178
nlsolve_kwargs = (;), kwargs...) where {iip, T}
179179
info::ReturnCode.T = ReturnCode.Success
180-
nlprob::NonlinearProblem = __construct_nlproblem(cache)
180+
nlprob = __construct_nlproblem(cache)
181181
nlsolve_alg = __concrete_nonlinearsolve_algorithm(nlprob, cache.alg.nlsolve)
182182
nlsol = __solve(nlprob, nlsolve_alg; abstol, kwargs..., nlsolve_kwargs...)
183183
error_norm = 2 * abstol
@@ -315,31 +315,50 @@ function __construct_nlproblem(cache::AscherCache{iip, T}) where {iip, T}
315315
else
316316
@closure (z, p) -> @views Φ(cache, z, pt)
317317
end
318+
318319
lz = reduce(vcat, cache.z)
319-
sd = alg.jac_alg.diffmode isa AutoSparse ? SymbolicsSparsityDetection() :
320-
NoSparsityDetection()
321-
ad = alg.jac_alg.diffmode
322-
lossₚ = (iip ? __Fix3 : Base.Fix2)(loss, cache.p)
323-
jac_cache = __sparse_jacobian_cache(Val(iip), ad, sd, lossₚ, lz, lz)
324-
jac_prototype = init_jacobian(jac_cache)
320+
resid_prototype = zero(lz)
321+
diffmode = if alg.jac_alg.diffmode isa AutoSparse
322+
#AutoSparse(get_dense_ad(alg.jac_alg.diffmode);
323+
# sparsity_detector = SparseConnectivityTracer.TracerSparsityDetector(),
324+
# coloring_algorithm = GreedyColoringAlgorithm(LargestFirst()))
325+
# Ascher collocation need more generalized collocation to support AutoSparse
326+
get_dense_ad(alg.jac_alg.diffmode)
327+
else
328+
alg.jac_alg.diffmode
329+
end
330+
331+
jac_cache = if iip
332+
DI.prepare_jacobian(loss, resid_prototype, diffmode, lz, Constant(cache.p))
333+
else
334+
DI.prepare_jacobian(loss, diffmode, lz, Constant(cache.p))
335+
end
336+
337+
jac_prototype = if iip
338+
DI.jacobian(loss, resid_prototype, jac_cache, diffmode, lz, Constant(cache.p))
339+
else
340+
DI.jacobian(loss, jac_cache, diffmode, lz, Constant(cache.p))
341+
end
342+
325343
jac = if iip
326-
@closure (J, u, p) -> __ascher_mpoint_jacobian!(J, u, ad, jac_cache, lossₚ, lz)
344+
@closure (J, u, p) -> __ascher_mpoint_jacobian!(
345+
J, u, diffmode, jac_cache, loss, lz, cache.p)
327346
else
328-
@closure (u, p) -> __ascher_mpoint_jacobian(jac_prototype, u, ad, jac_cache, lossₚ)
347+
@closure (u, p) -> __ascher_mpoint_jacobian(
348+
jac_prototype, u, diffmode, jac_cache, loss, cache.p)
329349
end
330-
resid_prototype = zero(lz)
331-
_nlf = NonlinearFunction{iip}(
350+
351+
nlf = NonlinearFunction{iip}(
332352
loss; jac = jac, resid_prototype = resid_prototype, jac_prototype = jac_prototype)
333-
nlprob::NonlinearProblem = NonlinearProblem(_nlf, lz, cache.p)
334-
return nlprob
353+
return __internal_nlsolve_problem(cache.prob, similar(lz), lz, nlf, lz, cache.p)
335354
end
336355

337-
function __ascher_mpoint_jacobian!(J, x, diffmode, diffcache, loss, resid)
338-
sparse_jacobian!(J, diffmode, diffcache, loss, resid, x)
356+
function __ascher_mpoint_jacobian!(J, x, diffmode, diffcache, loss, resid, p)
357+
DI.jacobian!(loss, resid, J, diffcache, diffmode, x, Constant(p))
339358
return nothing
340359
end
341-
function __ascher_mpoint_jacobian(J, x, diffmode, diffcache, loss)
342-
sparse_jacobian!(J, diffmode, diffcache, loss, x)
360+
function __ascher_mpoint_jacobian(J, x, diffmode, diffcache, loss, p)
361+
DI.jacobian!(loss, J, diffcache, diffmode, x, Constant(p))
343362
return J
344363
end
345364

0 commit comments

Comments
 (0)