Skip to content

Commit 8869f80

Browse files
committed
Towards a cleaner and more maintainable internals of NonlinearSolve.jl
1 parent bf3a132 commit 8869f80

File tree

9 files changed

+836
-1207
lines changed

9 files changed

+836
-1207
lines changed

.JuliaFormatter.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
style = "sciml"
2-
format_markdown = true
2+
format_markdown = true
3+
annotate_untyped_fields_with_any = false

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
name = "NonlinearSolve"
22
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
33
authors = ["SciML"]
4-
version = "1.10.0"
4+
version = "1.11.0"
55

66
[deps]
7+
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
78
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
9+
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
810
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
911
EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
1012
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"

src/NonlinearSolve.jl

Lines changed: 44 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,41 @@
11
module NonlinearSolve
2-
if isdefined(Base, :Experimental) &&
3-
isdefined(Base.Experimental, Symbol("@max_methods"))
2+
3+
if isdefined(Base, :Experimental) && isdefined(Base.Experimental, Symbol("@max_methods"))
44
@eval Base.Experimental.@max_methods 1
55
end
6-
using Reexport
7-
using UnPack: @unpack
8-
using FiniteDiff, ForwardDiff
9-
using ForwardDiff: Dual
10-
using LinearAlgebra
11-
using StaticArraysCore
12-
using RecursiveArrayTools
13-
import EnumX
14-
import ArrayInterface
15-
import LinearSolve
16-
using DiffEqBase
17-
using SparseDiffTools
18-
19-
@reexport using SciMLBase
20-
using SciMLBase: NLStats
21-
@reexport using SimpleNonlinearSolve
22-
23-
import SciMLBase: _unwrap_val
24-
25-
abstract type AbstractNonlinearSolveAlgorithm <: SciMLBase.AbstractNonlinearAlgorithm end
26-
abstract type AbstractNewtonAlgorithm{CS, AD, FDT, ST, CJ} <:
27-
AbstractNonlinearSolveAlgorithm end
28-
29-
function SciMLBase.__solve(prob::NonlinearProblem,
30-
alg::AbstractNonlinearSolveAlgorithm, args...;
31-
kwargs...)
6+
7+
using DiffEqBase, LinearAlgebra, LinearSolve, SparseDiffTools
8+
import ForwardDiff
9+
10+
import ADTypes: AbstractFiniteDifferencesMode
11+
import ArrayInterface: undefmatrix
12+
import ConcreteStructs: @concrete
13+
import EnumX: @enumx
14+
import ForwardDiff: Dual
15+
import LinearSolve: ComposePreconditioner, InvPreconditioner, needs_concrete_A
16+
import RecursiveArrayTools: AbstractVectorOfArray, recursivecopy!, recursivefill!
17+
import Reexport: @reexport
18+
import SciMLBase: AbstractNonlinearAlgorithm, NLStats, _unwrap_val, has_jac, isinplace
19+
import SparseDiffTools: __init_𝒥
20+
import StaticArraysCore: StaticArray, SVector
21+
import UnPack: @unpack
22+
23+
@reexport using ADTypes, SciMLBase, SimpleNonlinearSolve
24+
25+
const AbstractSparseADType = Union{ADTypes.AbstractSparseFiniteDifferences,
26+
ADTypes.AbstractSparseForwardMode, ADTypes.AbstractSparseReverseMode}
27+
28+
abstract type AbstractNonlinearSolveAlgorithm <: AbstractNonlinearAlgorithm end
29+
abstract type AbstractNewtonAlgorithm{CJ, AD} <: AbstractNonlinearSolveAlgorithm end
30+
31+
function SciMLBase.__solve(prob::NonlinearProblem, alg::AbstractNonlinearSolveAlgorithm,
32+
args...; kwargs...)
3233
cache = init(prob, alg, args...; kwargs...)
33-
sol = solve!(cache)
34+
return solve!(cache)
3435
end
3536

37+
# FIXME: Scalar Case is Completely Broken
38+
3639
include("utils.jl")
3740
include("raphson.jl")
3841
include("trustRegion.jl")
@@ -44,23 +47,23 @@ import PrecompileTools
4447

4548
PrecompileTools.@compile_workload begin
4649
for T in (Float32, Float64)
47-
prob = NonlinearProblem{false}((u, p) -> u .* u .- p, T(0.1), T(2))
50+
# prob = NonlinearProblem{false}((u, p) -> u .* u .- p, T(0.1), T(2))
4851

49-
precompile_algs = if VERSION >= v"1.7"
50-
(NewtonRaphson(), TrustRegion(), LevenbergMarquardt())
51-
else
52-
(NewtonRaphson(),)
53-
end
52+
# precompile_algs = if VERSION v"1.7"
53+
# (NewtonRaphson(), TrustRegion(), LevenbergMarquardt())
54+
# else
55+
# (NewtonRaphson(),)
56+
# end
5457

55-
for alg in precompile_algs
56-
solve(prob, alg, abstol = T(1e-2))
57-
end
58+
# for alg in precompile_algs
59+
# solve(prob, alg, abstol = T(1e-2))
60+
# end
5861

5962
prob = NonlinearProblem{true}((du, u, p) -> du[1] = u[1] * u[1] - p[1], T[0.1],
6063
T[2])
61-
for alg in precompile_algs
62-
solve(prob, alg, abstol = T(1e-2))
63-
end
64+
# for alg in precompile_algs
65+
# solve(prob, alg, abstol = T(1e-2))
66+
# end
6467
end
6568
end
6669

src/ad.jl

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -23,22 +23,17 @@ function scalar_nlsolve_ad(prob, alg, args...; kwargs...)
2323
return sol, partials
2424
end
2525

26-
function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, StaticArraysCore.SVector},
27-
iip,
28-
<:Dual{T, V, P}},
29-
alg::AbstractNewtonAlgorithm,
30-
args...; kwargs...) where {iip, T, V, P}
26+
function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, SVector}, iip,
27+
<:Dual{T, V, P}}, alg::AbstractNewtonAlgorithm, args...;
28+
kwargs...) where {iip, T, V, P}
3129
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
3230
return SciMLBase.build_solution(prob, alg, Dual{T, V, P}(sol.u, partials), sol.resid;
33-
retcode = sol.retcode)
31+
sol.retcode)
3432
end
35-
function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, StaticArraysCore.SVector},
36-
iip,
37-
<:AbstractArray{<:Dual{T, V, P}}},
38-
alg::AbstractNewtonAlgorithm,
39-
args...;
33+
function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, SVector}, iip,
34+
<:AbstractArray{<:Dual{T, V, P}}}, alg::AbstractNewtonAlgorithm, args...;
4035
kwargs...) where {iip, T, V, P}
4136
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
4237
return SciMLBase.build_solution(prob, alg, Dual{T, V, P}(sol.u, partials), sol.resid;
43-
retcode = sol.retcode)
38+
sol.retcode)
4439
end

src/jacobian.jl

Lines changed: 42 additions & 149 deletions
Original file line numberDiff line numberDiff line change
@@ -1,185 +1,78 @@
1-
struct JacobianWrapper{fType, pType}
2-
f::fType
3-
p::pType
1+
@concrete struct JacobianWrapper
2+
f
3+
p
44
end
55

66
(uf::JacobianWrapper)(u) = uf.f(u, uf.p)
77
(uf::JacobianWrapper)(res, u) = uf.f(res, u, uf.p)
88

9-
struct NonlinearSolveTag end
10-
11-
function sparsity_colorvec(f, x)
12-
sparsity = f.sparsity
13-
colorvec = DiffEqBase.has_colorvec(f) ? f.colorvec :
14-
(isnothing(sparsity) ? (1:length(x)) : matrix_colors(sparsity))
15-
sparsity, colorvec
16-
end
17-
18-
function jacobian_finitediff_forward!(J, f, x, jac_config, forwardcache, cache)
19-
(FiniteDiff.finite_difference_jacobian!(J, f, x, jac_config, forwardcache);
20-
maximum(jac_config.colorvec))
21-
end
22-
function jacobian_finitediff!(J, f, x, jac_config, cache)
23-
(FiniteDiff.finite_difference_jacobian!(J, f, x, jac_config);
24-
2 * maximum(jac_config.colorvec))
25-
end
9+
# function sparsity_colorvec(f, x)
10+
# sparsity = f.sparsity
11+
# colorvec = DiffEqBase.has_colorvec(f) ? f.colorvec :
12+
# (isnothing(sparsity) ? (1:length(x)) : matrix_colors(sparsity))
13+
# sparsity, colorvec
14+
# end
2615

2716
# NoOp for Jacobian if it is not a Abstract Array -- For eg, JacVec Operator
28-
jacobian!(J, cache) = J
29-
function jacobian!(J::AbstractMatrix{<:Number}, cache)
30-
f = cache.f
31-
uf = cache.uf
32-
x = cache.u
33-
fx = cache.fu
34-
jac_config = cache.jac_config
35-
alg = cache.alg
36-
37-
if SciMLBase.has_jac(f)
38-
f.jac(J, x, cache.p)
39-
elseif alg_autodiff(alg)
40-
forwarddiff_color_jacobian!(J, uf, x, jac_config)
41-
#cache.destats.nf += 1
17+
jacobian!!(J, _) = J
18+
# `!!` notation is from BangBang.jl since J might be jacobian in case of oop `f.jac`
19+
# and we don't want wasteful `copyto!`
20+
function jacobian!!(J::Union{AbstractMatrix{<:Number}, Nothing}, cache)
21+
@unpack f, uf, u, p, jac_cache, alg, fu2 = cache
22+
iip = isinplace(cache)
23+
if iip
24+
has_jac(f) ? f.jac(J, u, p) : sparse_jacobian!(J, alg.ad, jac_cache, uf, fu2, u)
4225
else
43-
isforward = alg_difftype(alg) === Val{:forward}
44-
if isforward
45-
uf(fx, x)
46-
#cache.destats.nf += 1
47-
tmp = jacobian_finitediff_forward!(J, uf, x, jac_config, fx,
48-
cache)
49-
else # not forward difference
50-
tmp = jacobian_finitediff!(J, uf, x, jac_config, cache)
51-
end
52-
#cache.destats.nf += tmp
26+
return has_jac(f) ? f.jac(u, p) : sparse_jacobian!(J, alg.ad, jac_cache, uf, u)
5327
end
54-
nothing
28+
return nothing
5529
end
5630

57-
function build_jac_and_jac_config(alg, f::F1, uf::F2, du1, u, tmp, du2) where {F1, F2}
31+
# Build Jacobian Caches
32+
function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f, u, p,
33+
::Val{iip}) where {iip}
34+
uf = JacobianWrapper(f, p)
35+
5836
haslinsolve = hasfield(typeof(alg), :linsolve)
5937

60-
has_analytic_jac = SciMLBase.has_jac(f)
38+
has_analytic_jac = has_jac(f)
6139
linsolve_needs_jac = (concrete_jac(alg) === nothing &&
6240
(!haslinsolve || (haslinsolve && (alg.linsolve === nothing ||
63-
LinearSolve.needs_concrete_A(alg.linsolve)))))
64-
alg_wants_jac = (concrete_jac(alg) !== nothing && concrete_jac(alg))
41+
needs_concrete_A(alg.linsolve)))))
42+
alg_wants_jac = (concrete_jac(alg) === nothing && concrete_jac(alg))
6543

44+
fu = zero(u) # TODO: Use Prototype
6645
if !has_analytic_jac && (linsolve_needs_jac || alg_wants_jac)
67-
sparsity, colorvec = sparsity_colorvec(f, u)
68-
69-
if alg_autodiff(alg)
70-
_chunksize = get_chunksize(alg) === Val(0) ? nothing : get_chunksize(alg) # SparseDiffEq uses different convection...
71-
72-
T = if standardtag(alg)
73-
typeof(ForwardDiff.Tag(NonlinearSolveTag(), eltype(u)))
74-
else
75-
typeof(ForwardDiff.Tag(uf, eltype(u)))
76-
end
77-
jac_config = ForwardColorJacCache(uf, u, _chunksize; colorvec, sparsity,
78-
tag = T)
79-
else
80-
if alg_difftype(alg) !== Val{:complex}
81-
jac_config = FiniteDiff.JacobianCache(tmp, du1, du2, alg_difftype(alg);
82-
colorvec, sparsity)
83-
else
84-
jac_config = FiniteDiff.JacobianCache(Complex{eltype(tmp)}.(tmp),
85-
Complex{eltype(du1)}.(du1), nothing, alg_difftype(alg), eltype(u);
86-
colorvec, sparsity)
87-
end
88-
end
46+
# TODO: We need an Upstream Mode to allow using known sparsity and colorvec
47+
# TODO: We can use the jacobian prototype here
48+
sd = typeof(alg.ad) <: AbstractSparseADType ? SymbolicsSparsityDetection() :
49+
NoSparsityDetection()
50+
jac_cache = iip ? sparse_jacobian_cache(alg.ad, sd, uf, fu, u) :
51+
sparse_jacobian_cache(alg.ad, sd, uf, u; fx=fu)
8952
else
90-
jac_config = nothing
53+
jac_cache = nothing
9154
end
9255

9356
J = if !linsolve_needs_jac
9457
# We don't need to construct the Jacobian
95-
JacVec(uf, u; autodiff = alg_autodiff(alg) ? AutoForwardDiff() : AutoFiniteDiff())
58+
JacVec(uf, u; autodiff = alg.ad)
9659
else
97-
if f.jac_prototype === nothing
98-
ArrayInterface.undefmatrix(u)
60+
if has_analytic_jac
61+
iip ? undefmatrix(u) : nothing
9962
else
100-
f.jac_prototype
63+
f.jac_prototype === nothing ? __init_𝒥(jac_cache) : f.jac_prototype
10164
end
10265
end
10366

104-
return J, jac_config
105-
end
106-
107-
# Build Jacobian Caches
108-
function jacobian_caches(alg::Union{NewtonRaphson, LevenbergMarquardt, TrustRegion}, f, u,
109-
p, ::Val{true})
110-
uf = JacobianWrapper(f, p)
111-
112-
du1 = zero(u)
113-
du2 = zero(u)
114-
tmp = zero(u)
115-
J, jac_config = build_jac_and_jac_config(alg, f, uf, du1, u, tmp, du2)
116-
67+
# FIXME: Assumes same sized `u` and `fu` -- Incorrect Assumption for Levenberg
11768
linprob = LinearProblem(J, _vec(zero(u)); u0 = _vec(zero(u)))
69+
11870
weight = similar(u)
11971
recursivefill!(weight, true)
12072

12173
Pl, Pr = wrapprecs(alg.precs(J, nothing, u, p, nothing, nothing, nothing, nothing,
12274
nothing)..., weight)
12375
linsolve = init(linprob, alg.linsolve; alias_A = true, alias_b = true, Pl, Pr)
12476

125-
uf, linsolve, J, du1, jac_config
126-
end
127-
128-
function get_chunksize(jac_config::ForwardDiff.JacobianConfig{
129-
T,
130-
V,
131-
N,
132-
D,
133-
}) where {T, V, N, D
134-
}
135-
Val(N)
136-
end # don't degrade compile time information to runtime information
137-
138-
function jacobian_finitediff(f, x, ::Type{diff_type}, dir, colorvec, sparsity,
139-
jac_prototype) where {diff_type}
140-
(FiniteDiff.finite_difference_derivative(f, x, diff_type, eltype(x), dir = dir), 2)
141-
end
142-
function jacobian_finitediff(f, x::AbstractArray, ::Type{diff_type}, dir, colorvec,
143-
sparsity, jac_prototype) where {diff_type}
144-
f_in = diff_type === Val{:forward} ? f(x) : similar(x)
145-
ret_eltype = eltype(f_in)
146-
J = FiniteDiff.finite_difference_jacobian(f, x, diff_type, ret_eltype, f_in,
147-
dir = dir, colorvec = colorvec,
148-
sparsity = sparsity,
149-
jac_prototype = jac_prototype)
150-
return J, _nfcount(maximum(colorvec), diff_type)
151-
end
152-
function jacobian(cache, f::F) where {F}
153-
x = cache.u
154-
alg = cache.alg
155-
uf = cache.uf
156-
local tmp
157-
158-
if DiffEqBase.has_jac(cache.f)
159-
J = f.jac(cache.u, cache.p)
160-
elseif alg_autodiff(alg)
161-
J, tmp = jacobian_autodiff(uf, x, cache.f, alg)
162-
else
163-
jac_prototype = cache.f.jac_prototype
164-
sparsity, colorvec = sparsity_colorvec(cache.f, x)
165-
dir = true
166-
J, tmp = jacobian_finitediff(uf, x, alg_difftype(alg), dir, colorvec, sparsity,
167-
jac_prototype)
168-
end
169-
J
170-
end
171-
172-
jacobian_autodiff(f, x, nonlinfun, alg) = (ForwardDiff.derivative(f, x), 1, alg)
173-
function jacobian_autodiff(f, x::AbstractArray, nonlinfun, alg)
174-
jac_prototype = nonlinfun.jac_prototype
175-
sparsity, colorvec = sparsity_colorvec(nonlinfun, x)
176-
maxcolor = maximum(colorvec)
177-
chunk_size = get_chunksize(alg) === Val(0) ? nothing : get_chunksize(alg)
178-
num_of_chunks = chunk_size === nothing ?
179-
Int(ceil(maxcolor /
180-
SparseDiffTools.getsize(ForwardDiff.pickchunksize(maxcolor)))) :
181-
Int(ceil(maxcolor / _unwrap_val(chunk_size)))
182-
(forwarddiff_color_jacobian(f, x, colorvec = colorvec, sparsity = sparsity,
183-
jac_prototype = jac_prototype, chunksize = chunk_size),
184-
num_of_chunks)
77+
return uf, linsolve, J, fu, jac_cache
18578
end

0 commit comments

Comments
 (0)