Skip to content

Commit 8f00979

Browse files
committed
refactor: centralize autodiff selection
1 parent 621c1b4 commit 8f00979

21 files changed

+247
-301
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ version = "4.0.0"
66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
88
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
9+
CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
910
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
1011
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
1112
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
@@ -65,6 +66,7 @@ ArrayInterface = "7.16"
6566
BandedMatrices = "1.5"
6667
BenchmarkTools = "1.4"
6768
CUDA = "5.5"
69+
CommonSolve = "0.2.4"
6870
ConcreteStructs = "0.2.3"
6971
DiffEqBase = "6.158.3"
7072
DifferentiationInterface = "0.6.1"

docs/src/native/solvers.md

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,16 @@ documentation.
2222
preconditioners. For more information on specifying preconditioners for LinearSolve
2323
algorithms, consult the
2424
[LinearSolve.jl documentation](https://docs.sciml.ai/LinearSolve/stable/).
25-
- `linesearch`: the line search algorithm to use. Defaults to [`NoLineSearch()`](@extref LineSearch.NoLineSearch),
26-
which means that no line search is performed.
27-
- `autodiff`/`jacobian_ad`: etermines the backend used for the Jacobian. Note that this
25+
- `linesearch`: the line search algorithm to use. Defaults to
26+
[`NoLineSearch()`](@extref LineSearch.NoLineSearch), which means that no line search is
27+
performed.
28+
- `autodiff`: determines the backend used for the Jacobian. Note that this
2829
argument is ignored if an analytical Jacobian is passed, as that will be used instead.
2930
Defaults to `nothing` which means that a default is selected according to the problem
3031
specification! Valid choices are types from ADTypes.jl.
31-
- `forward_ad`/`vjp_autodiff`: similar to `autodiff`, but is used to compute Jacobian
32+
- `jvp_autodiff`: similar to `autodiff`, but is used to compute Jacobian
3233
Vector Products. Ignored if the NonlinearFunction contains the `jvp` function.
33-
- `reverse_ad`/`vjp_autodiff`: similar to `autodiff`, but is used to compute Vector
34+
- `vjp_autodiff`: similar to `autodiff`, but is used to compute Vector
3435
Jacobian Products. Ignored if the NonlinearFunction contains the `vjp` function.
3536
- `concrete_jac`: whether to build a concrete Jacobian. If a Krylov-subspace method is
3637
used, then the Jacobian will not be constructed and instead direct Jacobian-Vector

docs/src/release_notes.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@
1515
- Use of termination conditions from `DiffEqBase` has been removed. Use the termination
1616
conditions from `NonlinearSolveBase` instead.
1717
- If no autodiff is provided, we now choose from a list of autodiffs based on the packages
18-
loaded. For example, if `Enzyme` is loaded, we will default to that. In general, we
19-
don't guarantee the exact autodiff selected if `autodiff` is not provided (i.e.
20-
`nothing`).
18+
loaded. For example, if `Enzyme` is loaded, we will default to that (for reverse mode).
19+
In general, we don't guarantee the exact autodiff selected if `autodiff` is not provided
20+
(i.e. `nothing`).
2121

2222
## Dec '23
2323

lib/NonlinearSolveBase/src/autodiff.jl

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,24 +3,34 @@
33

44
# Ordering is important here. We want to select the first one that is compatible with the
55
# problem.
6-
const ReverseADs = (
7-
ADTypes.AutoEnzyme(; mode = EnzymeCore.Reverse),
8-
ADTypes.AutoZygote(),
9-
ADTypes.AutoTracker(),
10-
ADTypes.AutoReverseDiff(; compile = true),
11-
ADTypes.AutoReverseDiff(),
12-
ADTypes.AutoFiniteDiff()
13-
)
6+
# XXX: Remove this once Enzyme is properly supported on Julia 1.11+
7+
@static if VERSION ≥ v"1.11-"
8+
const ReverseADs = (
9+
ADTypes.AutoZygote(),
10+
ADTypes.AutoTracker(),
11+
ADTypes.AutoReverseDiff(; compile = true),
12+
ADTypes.AutoReverseDiff(),
13+
ADTypes.AutoEnzyme(; mode = EnzymeCore.Reverse),
14+
ADTypes.AutoFiniteDiff()
15+
)
16+
else
17+
const ReverseADs = (
18+
ADTypes.AutoEnzyme(; mode = EnzymeCore.Reverse),
19+
ADTypes.AutoZygote(),
20+
ADTypes.AutoTracker(),
21+
ADTypes.AutoReverseDiff(; compile = true),
22+
ADTypes.AutoReverseDiff(),
23+
ADTypes.AutoFiniteDiff()
24+
)
25+
end
1426

1527
const ForwardADs = (
16-
ADTypes.AutoEnzyme(; mode = EnzymeCore.Forward),
1728
ADTypes.AutoPolyesterForwardDiff(),
1829
ADTypes.AutoForwardDiff(),
30+
ADTypes.AutoEnzyme(; mode = EnzymeCore.Forward),
1931
ADTypes.AutoFiniteDiff()
2032
)
2133

22-
# TODO: Handle Sparsity
23-
2434
function select_forward_mode_autodiff(
2535
prob::AbstractNonlinearProblem, ad::AbstractADType; warn_check_mode::Bool = true)
2636
if warn_check_mode && !(ADTypes.mode(ad) isa ADTypes.ForwardMode)

lib/SimpleNonlinearSolve/src/utils.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,6 @@ function prepare_jacobian(prob, autodiff, _, x::Number)
130130
if SciMLBase.has_jac(prob.f) || SciMLBase.has_vjp(prob.f) || SciMLBase.has_jvp(prob.f)
131131
return AnalyticJacobian()
132132
end
133-
# return DI.prepare_derivative(prob.f, autodiff, x, Constant(prob.p))
134133
return DINoPreparation()
135134
end
136135
function prepare_jacobian(prob, autodiff, fx, x)

src/NonlinearSolve.jl

Lines changed: 68 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ using PrecompileTools: @compile_workload, @setup_workload
55

66
using ArrayInterface: ArrayInterface, can_setindex, restructure, fast_scalar_indexing,
77
ismutable
8+
using CommonSolve: solve, init, solve!
89
using ConcreteStructs: @concrete
910
using DiffEqBase: DiffEqBase # Needed for `init` / `solve` dispatches
1011
using FastClosures: @closure
@@ -21,13 +22,18 @@ using NonlinearSolveBase: NonlinearSolveBase, nonlinearsolve_forwarddiff_solve,
2122
nonlinearsolve_dual_solution, nonlinearsolve_∂f_∂p,
2223
nonlinearsolve_∂f_∂u, L2_NORM, AbsNormTerminationMode,
2324
AbstractNonlinearTerminationMode,
24-
AbstractSafeBestNonlinearTerminationMode
25+
AbstractSafeBestNonlinearTerminationMode,
26+
select_forward_mode_autodiff, select_reverse_mode_autodiff,
27+
select_jacobian_autodiff
2528
using Printf: @printf
2629
using Preferences: Preferences, @load_preference, @set_preferences!
2730
using RecursiveArrayTools: recursivecopy!
28-
using SciMLBase: AbstractNonlinearAlgorithm, AbstractNonlinearProblem, _unwrap_val,
29-
isinplace, NLStats
31+
using SciMLBase: SciMLBase, AbstractNonlinearAlgorithm, AbstractNonlinearProblem,
32+
_unwrap_val, isinplace, NLStats, NonlinearFunction,
33+
NonlinearLeastSquaresProblem, NonlinearProblem, ReturnCode, get_du, step!,
34+
set_u!, LinearProblem, IdentityOperator
3035
using SciMLOperators: AbstractSciMLOperator
36+
using SimpleNonlinearSolve: SimpleNonlinearSolve
3137
using StaticArraysCore: StaticArray, SVector, SArray, MArray, Size, SMatrix
3238
using SymbolicIndexingInterface: SymbolicIndexingInterface, ParameterIndexingProxy,
3339
symbolic_container, parameter_values, state_values, getu,
@@ -95,65 +101,65 @@ include("internal/forward_diff.jl") # we need to define after the algorithms
95101
include("utils.jl")
96102
include("default.jl")
97103

98-
@setup_workload begin
99-
nlfuncs = ((NonlinearFunction{false}((u, p) -> u .* u .- p), 0.1),
100-
(NonlinearFunction{true}((du, u, p) -> du .= u .* u .- p), [0.1]))
101-
probs_nls = NonlinearProblem[]
102-
for (fn, u0) in nlfuncs
103-
push!(probs_nls, NonlinearProblem(fn, u0, 2.0))
104-
end
105-
106-
nls_algs = (
107-
NewtonRaphson(),
108-
TrustRegion(),
109-
LevenbergMarquardt(),
110-
Broyden(),
111-
Klement(),
112-
nothing
113-
)
114-
115-
probs_nlls = NonlinearLeastSquaresProblem[]
116-
nlfuncs = (
117-
(NonlinearFunction{false}((u, p) -> (u .^ 2 .- p)[1:1]), [0.1, 0.0]),
118-
(NonlinearFunction{false}((u, p) -> vcat(u .* u .- p, u .* u .- p)), [0.1, 0.1]),
119-
(
120-
NonlinearFunction{true}(
121-
(du, u, p) -> du[1] = u[1] * u[1] - p, resid_prototype = zeros(1)),
122-
[0.1, 0.0]),
123-
(
124-
NonlinearFunction{true}((du, u, p) -> du .= vcat(u .* u .- p, u .* u .- p),
125-
resid_prototype = zeros(4)),
126-
[0.1, 0.1]
127-
)
128-
)
129-
for (fn, u0) in nlfuncs
130-
push!(probs_nlls, NonlinearLeastSquaresProblem(fn, u0, 2.0))
131-
end
132-
133-
nlls_algs = (
134-
LevenbergMarquardt(),
135-
GaussNewton(),
136-
TrustRegion(),
137-
nothing
138-
)
139-
140-
@compile_workload begin
141-
@sync begin
142-
for T in (Float32, Float64), (fn, u0) in nlfuncs
143-
Threads.@spawn NonlinearProblem(fn, T.(u0), T(2))
144-
end
145-
for (fn, u0) in nlfuncs
146-
Threads.@spawn NonlinearLeastSquaresProblem(fn, u0, 2.0)
147-
end
148-
for prob in probs_nls, alg in nls_algs
149-
Threads.@spawn solve(prob, alg; abstol = 1e-2, verbose = false)
150-
end
151-
for prob in probs_nlls, alg in nlls_algs
152-
Threads.@spawn solve(prob, alg; abstol = 1e-2, verbose = false)
153-
end
154-
end
155-
end
156-
end
104+
# @setup_workload begin
105+
# nlfuncs = ((NonlinearFunction{false}((u, p) -> u .* u .- p), 0.1),
106+
# (NonlinearFunction{true}((du, u, p) -> du .= u .* u .- p), [0.1]))
107+
# probs_nls = NonlinearProblem[]
108+
# for (fn, u0) in nlfuncs
109+
# push!(probs_nls, NonlinearProblem(fn, u0, 2.0))
110+
# end
111+
112+
# nls_algs = (
113+
# NewtonRaphson(),
114+
# TrustRegion(),
115+
# LevenbergMarquardt(),
116+
# Broyden(),
117+
# Klement(),
118+
# nothing
119+
# )
120+
121+
# probs_nlls = NonlinearLeastSquaresProblem[]
122+
# nlfuncs = (
123+
# (NonlinearFunction{false}((u, p) -> (u .^ 2 .- p)[1:1]), [0.1, 0.0]),
124+
# (NonlinearFunction{false}((u, p) -> vcat(u .* u .- p, u .* u .- p)), [0.1, 0.1]),
125+
# (
126+
# NonlinearFunction{true}(
127+
# (du, u, p) -> du[1] = u[1] * u[1] - p, resid_prototype = zeros(1)),
128+
# [0.1, 0.0]),
129+
# (
130+
# NonlinearFunction{true}((du, u, p) -> du .= vcat(u .* u .- p, u .* u .- p),
131+
# resid_prototype = zeros(4)),
132+
# [0.1, 0.1]
133+
# )
134+
# )
135+
# for (fn, u0) in nlfuncs
136+
# push!(probs_nlls, NonlinearLeastSquaresProblem(fn, u0, 2.0))
137+
# end
138+
139+
# nlls_algs = (
140+
# LevenbergMarquardt(),
141+
# GaussNewton(),
142+
# TrustRegion(),
143+
# nothing
144+
# )
145+
146+
# @compile_workload begin
147+
# @sync begin
148+
# for T in (Float32, Float64), (fn, u0) in nlfuncs
149+
# Threads.@spawn NonlinearProblem(fn, T.(u0), T(2))
150+
# end
151+
# for (fn, u0) in nlfuncs
152+
# Threads.@spawn NonlinearLeastSquaresProblem(fn, u0, 2.0)
153+
# end
154+
# for prob in probs_nls, alg in nls_algs
155+
# Threads.@spawn solve(prob, alg; abstol = 1e-2, verbose = false)
156+
# end
157+
# for prob in probs_nlls, alg in nlls_algs
158+
# Threads.@spawn solve(prob, alg; abstol = 1e-2, verbose = false)
159+
# end
160+
# end
161+
# end
162+
# end
157163

158164
# Reexports
159165
@reexport using SciMLBase, SimpleNonlinearSolve, NonlinearSolveBase

src/algorithms/broyden.jl

Lines changed: 30 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
Broyden(; max_resets::Int = 100, linesearch = NoLineSearch(), reset_tolerance = nothing,
2+
Broyden(; max_resets::Int = 100, linesearch = nothing, reset_tolerance = nothing,
33
init_jacobian::Val = Val(:identity), autodiff = nothing, alpha = nothing)
44
55
An implementation of `Broyden`'s Method [broyden1965class](@cite) with resetting and line
@@ -29,36 +29,37 @@ search.
2929
problem
3030
"""
3131
function Broyden(;
32-
max_resets = 100, linesearch = NoLineSearch(), reset_tolerance = nothing,
33-
init_jacobian::Val{IJ} = Val(:identity), autodiff = nothing,
34-
alpha = nothing, update_rule::Val{UR} = Val(:good_broyden)) where {IJ, UR}
35-
if IJ === :identity
36-
if UR === :diagonal
37-
initialization = IdentityInitialization(alpha, DiagonalStructure())
38-
else
39-
initialization = IdentityInitialization(alpha, FullStructure())
40-
end
41-
elseif IJ === :true_jacobian
42-
initialization = TrueJacobianInitialization(FullStructure(), autodiff)
43-
else
44-
throw(ArgumentError("`init_jacobian` must be one of `:identity` or \
45-
`:true_jacobian`"))
46-
end
32+
max_resets = 100, linesearch = nothing, reset_tolerance = nothing,
33+
init_jacobian = Val(:identity), autodiff = nothing, alpha = nothing,
34+
update_rule = Val(:good_broyden))
35+
initialization = broyden_init(init_jacobian, update_rule, autodiff, alpha)
36+
update_rule = broyden_update_rule(update_rule)
37+
return ApproximateJacobianSolveAlgorithm{
38+
init_jacobian isa Val{:true_jacobian}, :Broyden}(;
39+
linesearch, descent = NewtonDescent(), update_rule, max_resets, initialization,
40+
reinit_rule = NoChangeInStateReset(; reset_tolerance))
41+
end
4742

48-
update_rule = if UR === :good_broyden
49-
GoodBroydenUpdateRule()
50-
elseif UR === :bad_broyden
51-
BadBroydenUpdateRule()
52-
elseif UR === :diagonal
53-
GoodBroydenUpdateRule()
54-
else
55-
throw(ArgumentError("`update_rule` must be one of `:good_broyden`, `:bad_broyden`, \
56-
or `:diagonal`"))
57-
end
43+
function broyden_init(::Val{:identity}, ::Val{:diagonal}, autodiff, alpha)
44+
return IdentityInitialization(alpha, DiagonalStructure())
45+
end
46+
function broyden_init(::Val{:identity}, ::Val, autodiff, alpha)
47+
IdentityInitialization(alpha, FullStructure())
48+
end
49+
function broyden_init(::Val{:true_jacobian}, ::Val, autodiff, alpha)
50+
return TrueJacobianInitialization(FullStructure(), autodiff)
51+
end
52+
function broyden_init(::Val{IJ}, ::Val{UR}, autodiff, alpha) where {IJ, UR}
53+
error("Unknown combination of `init_jacobian = Val($(Meta.quot(IJ)))` and \
54+
`update_rule = Val($(Meta.quot(UR)))`. Please choose a valid combination.")
55+
end
5856

59-
return ApproximateJacobianSolveAlgorithm{IJ === :true_jacobian, :Broyden}(;
60-
linesearch, descent = NewtonDescent(), update_rule, max_resets,
61-
initialization, reinit_rule = NoChangeInStateReset(; reset_tolerance))
57+
broyden_update_rule(::Val{:good_broyden}) = GoodBroydenUpdateRule()
58+
broyden_update_rule(::Val{:bad_broyden}) = BadBroydenUpdateRule()
59+
broyden_update_rule(::Val{:diagonal}) = GoodBroydenUpdateRule()
60+
function broyden_update_rule(::Val{UR}) where {UR}
61+
error("Unknown update rule `update_rule = Val($(Meta.quot(UR)))`. Please choose a \
62+
valid update rule.")
6263
end
6364

6465
# Checks for no significant change for `nsteps`

src/algorithms/dfsane.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
# XXX: remove kwargs with unicode
12
"""
23
DFSane(; σ_min = 1 // 10^10, σ_max = 1e10, σ_1 = 1, M::Int = 10, γ = 1 // 10^4,
34
τ_min = 1 // 10, τ_max = 1 // 2, n_exp::Int = 2, max_inner_iterations::Int = 100,

src/algorithms/gauss_newton.jl

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
"""
2-
GaussNewton(; concrete_jac = nothing, linsolve = nothing, linesearch = NoLineSearch(),
3-
precs = DEFAULT_PRECS, adkwargs...)
2+
GaussNewton(; concrete_jac = nothing, linsolve = nothing, precs = DEFAULT_PRECS,
3+
linesearch = nothing, vjp_autodiff = nothing, autodiff = nothing,
4+
jvp_autodiff = nothing)
45
56
An advanced GaussNewton implementation with support for efficient handling of sparse
67
matrices via colored automatic differentiation and preconditioned linear solvers. Designed
78
for large-scale and numerically-difficult nonlinear least squares problems.
89
"""
910
function GaussNewton(; concrete_jac = nothing, linsolve = nothing, precs = DEFAULT_PRECS,
10-
linesearch = NoLineSearch(), vjp_autodiff = nothing, autodiff = nothing)
11-
descent = NewtonDescent(; linsolve, precs)
12-
return GeneralizedFirstOrderAlgorithm(; concrete_jac, name = :GaussNewton, descent,
13-
jacobian_ad = autodiff, reverse_ad = vjp_autodiff, linesearch)
11+
linesearch = nothing, vjp_autodiff = nothing, autodiff = nothing,
12+
jvp_autodiff = nothing)
13+
return GeneralizedFirstOrderAlgorithm{concrete_jac, :GaussNewton}(; linesearch,
14+
descent = NewtonDescent(; linsolve, precs), autodiff, vjp_autodiff, jvp_autodiff)
1415
end

0 commit comments

Comments
 (0)