Skip to content

Commit 631c4d6

Browse files
authored
Merge branch 'master' into os/use-LinearSolve-precs
2 parents c4e6def + d6d741b commit 631c4d6

File tree

14 files changed

+252
-166
lines changed

14 files changed

+252
-166
lines changed

Project.toml

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
88
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
99
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
1010
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
11+
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
1112
FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
1213
FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
1314
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
@@ -58,20 +59,21 @@ NonlinearSolveSymbolicsExt = "Symbolics"
5859
NonlinearSolveZygoteExt = "Zygote"
5960

6061
[compat]
61-
ADTypes = "1.1.0"
62+
ADTypes = "1.9"
6263
Aqua = "0.8"
63-
ArrayInterface = "7.9"
64+
ArrayInterface = "7.16"
6465
BandedMatrices = "1.5"
6566
BenchmarkTools = "1.4"
66-
CUDA = "5.2"
67+
CUDA = "5.5"
6768
ConcreteStructs = "0.2.3"
68-
DiffEqBase = "6.149.0"
69-
Enzyme = "0.12"
69+
DiffEqBase = "6.155.3"
70+
DifferentiationInterface = "0.6.1"
71+
Enzyme = "0.13.2"
7072
ExplicitImports = "1.5"
71-
FastBroadcast = "0.2.8, 0.3"
73+
FastBroadcast = "0.3.5"
7274
FastClosures = "0.3.2"
7375
FastLevenbergMarquardt = "0.1"
74-
FiniteDiff = "2.22"
76+
FiniteDiff = "2.24"
7577
FixedPointAcceleration = "0.3"
7678
ForwardDiff = "0.10.36"
7779
Hwloc = "3"
@@ -80,36 +82,36 @@ LazyArrays = "1.8.2, 2"
8082
LeastSquaresOptim = "0.8.5"
8183
LineSearches = "7.2"
8284
LinearAlgebra = "1.10"
83-
LinearSolve = "2.30"
85+
LinearSolve = "2.35"
8486
MINPACK = "1.2"
8587
MaybeInplace = "0.1.3"
86-
ModelingToolkit = "9.15.0"
88+
ModelingToolkit = "9.41.0"
8789
NLSolvers = "0.5"
8890
NLsolve = "4.5"
8991
NaNMath = "1"
9092
NonlinearProblemLibrary = "0.1.2"
91-
OrdinaryDiffEq = "6.75"
93+
OrdinaryDiffEqTsit5 = "1.1.0"
9294
Pkg = "1.10"
9395
PrecompileTools = "1.2"
9496
Preferences = "1.4"
9597
Printf = "1.10"
9698
Random = "1.91"
9799
ReTestItems = "1.24"
98-
RecursiveArrayTools = "3.8"
100+
RecursiveArrayTools = "3.27"
99101
Reexport = "1.2"
100102
SIAMFANLEquations = "1.0.1"
101-
SciMLBase = "2.34.0"
103+
SciMLBase = "2.54.0"
102104
SciMLJacobianOperators = "0.1"
103-
SimpleNonlinearSolve = "1.8"
105+
SimpleNonlinearSolve = "1.12.3"
104106
SparseArrays = "1.10"
105-
SparseDiffTools = "2.19"
107+
SparseDiffTools = "2.22"
106108
SpeedMapping = "0.3"
107109
StableRNGs = "1"
108110
StaticArrays = "1.9"
109111
StaticArraysCore = "1.4"
110112
Sundials = "4.23.1"
111-
SymbolicIndexingInterface = "0.3.15"
112-
Symbolics = "5.26, 6"
113+
SymbolicIndexingInterface = "0.3.31"
114+
Symbolics = "6.12"
113115
Test = "1.10"
114116
TimerOutputs = "0.5.23"
115117
Zygote = "0.6.69"
@@ -133,7 +135,7 @@ NLSolvers = "337daf1e-9722-11e9-073e-8b9effe078ba"
133135
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
134136
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
135137
NonlinearProblemLibrary = "b7050fa9-e91f-4b37-bcee-a89a063da141"
136-
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
138+
OrdinaryDiffEqTsit5 = "b1df2697-797e-41e3-8120-5422d3b24e4a"
137139
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
138140
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
139141
ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
@@ -147,4 +149,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
147149
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
148150

149151
[targets]
150-
test = ["Aqua", "BandedMatrices", "BenchmarkTools", "CUDA", "Enzyme", "ExplicitImports", "FastLevenbergMarquardt", "FixedPointAcceleration", "Hwloc", "InteractiveUtils", "LeastSquaresOptim", "MINPACK", "ModelingToolkit", "NLSolvers", "NLsolve", "NaNMath", "NonlinearProblemLibrary", "OrdinaryDiffEq", "Pkg", "Random", "ReTestItems", "SIAMFANLEquations", "SpeedMapping", "StableRNGs", "StaticArrays", "Sundials", "Symbolics", "Test", "Zygote"]
152+
test = ["Aqua", "BandedMatrices", "BenchmarkTools", "CUDA", "Enzyme", "ExplicitImports", "FastLevenbergMarquardt", "FixedPointAcceleration", "Hwloc", "InteractiveUtils", "LeastSquaresOptim", "MINPACK", "ModelingToolkit", "NLSolvers", "NLsolve", "NaNMath", "NonlinearProblemLibrary", "OrdinaryDiffEqTsit5", "Pkg", "Random", "ReTestItems", "SIAMFANLEquations", "SpeedMapping", "StableRNGs", "StaticArrays", "Sundials", "Symbolics", "Test", "Zygote"]

docs/Project.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
1010
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
1111
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
1212
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
13-
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
13+
OrdinaryDiffEqTsit5 = "b1df2697-797e-41e3-8120-5422d3b24e4a"
1414
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
1515
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1616
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
@@ -30,10 +30,11 @@ DiffEqBase = "6.136"
3030
Documenter = "1"
3131
DocumenterCitations = "1"
3232
IncompleteLU = "0.2"
33+
InteractiveUtils = "<0.0.1, 1"
3334
LinearSolve = "2"
3435
ModelingToolkit = "8, 9"
3536
NonlinearSolve = "3"
36-
OrdinaryDiffEq = "6"
37+
OrdinaryDiffEqTsit5 = "1.1.0"
3738
Plots = "1"
3839
Random = "<0.0.1, 1"
3940
SciMLBase = "2.4"

docs/src/tutorials/code_optimization.md

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,7 @@ Take for example a prototypical small nonlinear solver code in its out-of-place
3333
```@example small_opt
3434
using NonlinearSolve
3535
36-
function f(u, p)
37-
u .* u .- p
38-
end
36+
f(u, p) = u .* u .- p
3937
u0 = [1.0, 1.0]
4038
p = 2.0
4139
prob = NonlinearProblem(f, u0, p)
@@ -53,9 +51,7 @@ using BenchmarkTools
5351
Note that this way of writing the function is a shorthand for:
5452

5553
```@example small_opt
56-
function f(u, p)
57-
[u[1] * u[1] - p, u[2] * u[2] - p]
58-
end
54+
f(u, p) = [u[1] * u[1] - p, u[2] * u[2] - p]
5955
```
6056

6157
where the function `f` returns an array. This is a common pattern from things like MATLAB's
@@ -71,7 +67,7 @@ by hand, this looks like:
7167
function f(du, u, p)
7268
du[1] = u[1] * u[1] - p
7369
du[2] = u[2] * u[2] - p
74-
nothing
70+
return nothing
7571
end
7672
7773
prob = NonlinearProblem(f, u0, p)
@@ -84,6 +80,7 @@ the `.=` in-place broadcasting.
8480
```@example small_opt
8581
function f(du, u, p)
8682
du .= u .* u .- p
83+
return nothing
8784
end
8885
8986
@benchmark sol = solve(prob, NewtonRaphson())
@@ -114,6 +111,7 @@ to normal array expressions, for example:
114111

115112
```@example small_opt
116113
using StaticArrays
114+
117115
A = SA[2.0, 3.0, 5.0]
118116
typeof(A)
119117
```
@@ -135,22 +133,20 @@ want to use the out-of-place allocating form, but this time we want to output a
135133
array. Doing it with broadcasting looks like:
136134

137135
```@example small_opt
138-
function f_SA(u, p)
139-
u .* u .- p
140-
end
136+
f_SA(u, p) = u .* u .- p
137+
141138
u0 = SA[1.0, 1.0]
142139
p = 2.0
143140
prob = NonlinearProblem(f_SA, u0, p)
141+
144142
@benchmark solve(prob, NewtonRaphson())
145143
```
146144

147145
Note that only change here is that `u0` is made into a StaticArray! If we needed to write
148146
`f` out for a more complex nonlinear case, then we'd simply do the following:
149147

150148
```@example small_opt
151-
function f_SA(u, p)
152-
SA[u[1] * u[1] - p, u[2] * u[2] - p]
153-
end
149+
f_SA(u, p) = SA[u[1] * u[1] - p, u[2] * u[2] - p]
154150
155151
@benchmark solve(prob, NewtonRaphson())
156152
```

docs/src/tutorials/optimizing_parameterized_ode.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ Let us fit a parameterized ODE to some data. We will use the Lotka-Volterra mode
44
example. We will use Single Shooting to fit the parameters.
55

66
```@example parameterized_ode
7-
using OrdinaryDiffEq, NonlinearSolve, Plots
7+
using OrdinaryDiffEqTsit5, NonlinearSolve, Plots
88
```
99

1010
Let us simulate some real data from the Lotka-Volterra model.

lib/SciMLJacobianOperators/Project.toml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
88
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
99
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
1010
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
11-
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
1211
FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
1312
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1413
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
@@ -19,9 +18,8 @@ ADTypes = "1.8.1"
1918
Aqua = "0.8.7"
2019
ConcreteStructs = "0.2.3"
2120
ConstructionBase = "1.5"
22-
DifferentiationInterface = "0.5"
21+
DifferentiationInterface = "0.6.1"
2322
Enzyme = "0.12, 0.13"
24-
EnzymeCore = "0.7, 0.8"
2523
ExplicitImports = "1.9.0"
2624
FastClosures = "0.3.2"
2725
FiniteDiff = "2.24.0"

lib/SciMLJacobianOperators/README.md

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# SciMLJacobianOperators.jl
2+
3+
SciMLJacobianOperators provides a convenient way to compute Jacobian-Vector Product (JVP)
4+
and Vector-Jacobian Product (VJP) using
5+
[SciMLOperators.jl](https://github.com/SciML/SciMLOperators.jl) and
6+
[DifferentiationInterface.jl](https://github.com/gdalle/DifferentiationInterface.jl).
7+
8+
Currently we have interfaces for:
9+
10+
- `NonlinearProblem`
11+
- `NonlinearLeastSquaresProblem`
12+
13+
and all autodiff backends supported by DifferentiationInterface.jl are supported.
14+
15+
## Example
16+
17+
```julia
18+
using SciMLJacobianOperators, NonlinearSolve, Enzyme, ForwardDiff
19+
20+
# Define the problem
21+
f(u, p) = u .* u .- p
22+
u0 = ones(4)
23+
p = 2.0
24+
prob = NonlinearProblem(f, u0, p)
25+
fu0 = f(u0, p)
26+
v = ones(4) .* 2
27+
28+
# Construct the operator
29+
jac_op = JacobianOperator(
30+
prob, fu0, u0;
31+
jvp_autodiff = AutoForwardDiff(),
32+
vjp_autodiff = AutoEnzyme(; mode = Enzyme.Reverse)
33+
)
34+
sjac_op = StatefulJacobianOperator(jac_op, u0, p)
35+
36+
sjac_op * v # Computes the JVP
37+
# 4-element Vector{Float64}:
38+
# 4.0
39+
# 4.0
40+
# 4.0
41+
# 4.0
42+
43+
sjac_op' * v # Computes the VJP
44+
# 4-element Vector{Float64}:
45+
# 4.0
46+
# 4.0
47+
# 4.0
48+
# 4.0
49+
50+
# What if we multiply the VJP and JVP?
51+
snormal_form = sjac_op' * sjac_op
52+
53+
snormal_form * v # Computes JᵀJ * v
54+
# 4-element Vector{Float64}:
55+
# 8.0
56+
# 8.0
57+
# 8.0
58+
# 8.0
59+
```

lib/SciMLJacobianOperators/src/SciMLJacobianOperators.jl

Lines changed: 22 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
module SciMLJacobianOperators
22

3-
using ADTypes: ADTypes, AutoSparse, AutoEnzyme
3+
using ADTypes: ADTypes, AutoSparse
44
using ConcreteStructs: @concrete
55
using ConstructionBase: ConstructionBase
6-
using DifferentiationInterface: DifferentiationInterface
7-
using EnzymeCore: EnzymeCore
6+
using DifferentiationInterface: DifferentiationInterface, Constant
87
using FastClosures: @closure
98
using LinearAlgebra: LinearAlgebra
109
using SciMLBase: SciMLBase, AbstractNonlinearProblem, AbstractNonlinearFunction
@@ -112,10 +111,10 @@ function JacobianOperator(prob::AbstractNonlinearProblem, fu, u; jvp_autodiff =
112111
iip = SciMLBase.isinplace(prob)
113112
T = promote_type(eltype(u), eltype(fu))
114113

115-
vjp_autodiff = set_function_as_const(get_dense_ad(vjp_autodiff))
114+
vjp_autodiff = get_dense_ad(vjp_autodiff)
116115
vjp_op = prepare_vjp(skip_vjp, prob, f, u, fu; autodiff = vjp_autodiff)
117116

118-
jvp_autodiff = set_function_as_const(get_dense_ad(jvp_autodiff))
117+
jvp_autodiff = get_dense_ad(jvp_autodiff)
119118
jvp_op = prepare_jvp(skip_jvp, prob, f, u, fu; autodiff = jvp_autodiff)
120119

121120
output_cache = fu isa Number ? T(fu) : similar(fu, T)
@@ -295,23 +294,21 @@ function prepare_vjp(::Val{false}, prob::AbstractNonlinearProblem,
295294

296295
@assert autodiff!==nothing "`vjp_autodiff` must be provided if `f` doesn't have \
297296
analytic `vjp` or `jac`."
298-
# TODO: Once DI supports const params we can use `p`
299-
fₚ = SciMLBase.JacobianWrapper{SciMLBase.isinplace(f)}(f, prob.p)
300297
if SciMLBase.isinplace(f)
301-
@assert DI.check_twoarg(autodiff) "Backend: $(autodiff) doesn't support in-place \
302-
problems."
298+
@assert DI.check_inplace(autodiff) "Backend: $(autodiff) doesn't support in-place \
299+
problems."
303300
fu_cache = copy(fu)
304-
v_fake = copy(fu)
305-
di_extras = DI.prepare_pullback(fₚ, fu_cache, autodiff, u, v_fake)
301+
di_extras = DI.prepare_pullback(f, fu_cache, autodiff, u, (fu,), Constant(prob.p))
306302
return @closure (vJ, v, u, p) -> begin
307-
DI.pullback!(fₚ, fu_cache, reshape(vJ, size(u)), autodiff,
308-
u, reshape(v, size(fu_cache)), di_extras)
303+
DI.pullback!(f, fu_cache, (reshape(vJ, size(u)),), di_extras, autodiff,
304+
u, (reshape(v, size(fu_cache)),), Constant(p))
309305
return
310306
end
311307
else
312-
di_extras = DI.prepare_pullback(fₚ, autodiff, u, fu)
308+
di_extras = DI.prepare_pullback(f, autodiff, u, (fu,), Constant(prob.p))
313309
return @closure (v, u, p) -> begin
314-
return DI.pullback(fₚ, autodiff, u, reshape(v, size(fu)), di_extras)
310+
return only(DI.pullback(
311+
f, di_extras, autodiff, u, (reshape(v, size(fu)),), Constant(p)))
315312
end
316313
end
317314
end
@@ -342,23 +339,21 @@ function prepare_jvp(::Val{false}, prob::AbstractNonlinearProblem,
342339

343340
@assert autodiff!==nothing "`jvp_autodiff` must be provided if `f` doesn't have \
344341
analytic `vjp` or `jac`."
345-
# TODO: Once DI supports const params we can use `p`
346-
fₚ = SciMLBase.JacobianWrapper{SciMLBase.isinplace(f)}(f, prob.p)
347342
if SciMLBase.isinplace(f)
348-
@assert DI.check_twoarg(autodiff) "Backend: $(autodiff) doesn't support in-place \
349-
problems."
343+
@assert DI.check_inplace(autodiff) "Backend: $(autodiff) doesn't support in-place \
344+
problems."
350345
fu_cache = copy(fu)
351-
di_extras = DI.prepare_pushforward(fₚ, fu_cache, autodiff, u, u)
346+
di_extras = DI.prepare_pushforward(f, fu_cache, autodiff, u, (u,), Constant(prob.p))
352347
return @closure (Jv, v, u, p) -> begin
353-
DI.pushforward!(
354-
fₚ, fu_cache, reshape(Jv, size(fu_cache)),
355-
autodiff, u, reshape(v, size(u)), di_extras)
348+
DI.pushforward!(f, fu_cache, (reshape(Jv, size(fu_cache)),), di_extras,
349+
autodiff, u, (reshape(v, size(u)),), Constant(p))
356350
return
357351
end
358352
else
359-
di_extras = DI.prepare_pushforward(fₚ, autodiff, u, u)
353+
di_extras = DI.prepare_pushforward(f, autodiff, u, (u,), Constant(prob.p))
360354
return @closure (v, u, p) -> begin
361-
return DI.pushforward(fₚ, autodiff, u, reshape(v, size(u)), di_extras)
355+
return only(DI.pushforward(
356+
f, di_extras, autodiff, u, (reshape(v, size(u)),), Constant(p)))
362357
end
363358
end
364359
end
@@ -371,10 +366,8 @@ function prepare_scalar_op(::Val{false}, prob::AbstractNonlinearProblem,
371366

372367
@assert autodiff!==nothing "`autodiff` must be provided if `f` doesn't have \
373368
analytic `vjp` or `jvp` or `jac`."
374-
# TODO: Once DI supports const params we can use `p`
375-
fₚ = Base.Fix2(f, prob.p)
376-
di_extras = DI.prepare_derivative(fₚ, autodiff, u)
377-
return @closure (v, u, p) -> DI.derivative(fₚ, autodiff, u, di_extras) * v
369+
di_extras = DI.prepare_derivative(f, autodiff, u, Constant(prob.p))
370+
return @closure (v, u, p) -> DI.derivative(f, di_extras, autodiff, u, Constant(p)) * v
378371
end
379372

380373
get_dense_ad(::Nothing) = nothing
@@ -386,12 +379,6 @@ function get_dense_ad(ad::AutoSparse)
386379
return dense_ad
387380
end
388381

389-
# In our case we know that it is safe to mark the function as const
390-
set_function_as_const(ad) = ad
391-
function set_function_as_const(ad::AutoEnzyme{M, Nothing}) where {M}
392-
return AutoEnzyme(; ad.mode, function_annotation = EnzymeCore.Const)
393-
end
394-
395382
export JacobianOperator, VecJacOperator, JacVecOperator
396383
export StatefulJacobianOperator
397384
export StatefulJacobianNormalFormOperator

0 commit comments

Comments
 (0)