Skip to content

Commit aedd385

Browse files
Merge pull request #201 from avik-pal/ap/jnfk
Add support for JFNK
2 parents e93fc58 + c006772 commit aedd385

File tree

11 files changed

+98
-105
lines changed

11 files changed

+98
-105
lines changed

Project.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,14 @@ julia = "1.6"
4141
[extras]
4242
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
4343
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
44+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
45+
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
4446
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
47+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
4548
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
4649
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
4750
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
4851
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
4952

5053
[targets]
51-
test = ["BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics"]
54+
test = ["BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra"]

docs/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
[deps]
22
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
33
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
4+
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
45
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
56
NonlinearSolveMINPACK = "c100e077-885d-495a-a2ea-599e143bf69d"
67
SciMLNLSolve = "e9a6253c-8580-4d32-9898-8661bb511710"
@@ -12,6 +13,7 @@ Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
1213
[compat]
1314
BenchmarkTools = "1"
1415
Documenter = "0.27"
16+
LinearSolve = "2"
1517
NonlinearSolve = "1"
1618
NonlinearSolveMINPACK = "0.1"
1719
SciMLNLSolve = "0.1"

docs/src/solvers/BracketingSolvers.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ Solves for ``f(t)=0`` in the problem defined by `prob` using the algorithm
77

88
## Recommended Methods
99

10-
`ITP()` is the recommended method for the scalar interval root-finding problems. It is particularly well-suited for cases where the function is smooth and well-behaved; and achieved superlinear convergence while retaining the optimal worst-case performance of the Bisection method. For more details, consult the detailed solver API docs.
10+
`ITP()` is the recommended method for the scalar interval root-finding problems. It is particularly well-suited for cases where the function is smooth and well-behaved; and achieved superlinear convergence while retaining the optimal worst-case performance of the Bisection method. For more details, consult the detailed solver API docs.
1111
`Ridder` is a hybrid method that uses the value of function at the midpoint of the interval to perform an exponential interpolation to the root. This gives a fast convergence with a guaranteed convergence of at most twice the number of iterations as the bisection method.
1212
`Brent` is a combination of the bisection method, the secant method and inverse quadratic interpolation. At every iteration, Brent's method decides which method out of these three is likely to do best, and proceeds by doing a step according to that method. This gives a robust and fast method, which therefore enjoys considerable popularity.
1313

docs/src/tutorials/nonlinear.md

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,21 @@ uspan = (1.0, 2.0) # brackets
3131
probB = IntervalNonlinearProblem(f, uspan)
3232
sol = solve(probB, Falsi())
3333
```
34+
35+
## Using Jacobian Free Newton Krylov (JNFK) Methods
36+
37+
If we want to solve the first example, without constructing the entire Jacobian
38+
39+
```@example
40+
using NonlinearSolve, LinearSolve
41+
42+
function f!(res, u, p)
43+
@. res = u * u - p
44+
end
45+
u0 = [1.0, 1.0]
46+
p = 2.0
47+
probN = NonlinearProblem(f!, u0, p)
48+
49+
linsolve = LinearSolve.KrylovJL_GMRES()
50+
sol = solve(probN, NewtonRaphson(; linsolve), reltol = 1e-9)
51+
```

src/NonlinearSolve.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,10 @@ function SciMLBase.__solve(prob::NonlinearProblem,
3434
end
3535

3636
include("utils.jl")
37-
include("jacobian.jl")
3837
include("raphson.jl")
3938
include("trustRegion.jl")
4039
include("levenberg.jl")
40+
include("jacobian.jl")
4141
include("ad.jl")
4242

4343
import PrecompileTools

src/jacobian.jl

Lines changed: 50 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ function jacobian_finitediff!(J, f, x, jac_config, cache)
2424
2 * maximum(jac_config.colorvec))
2525
end
2626

27+
# NoOp for Jacobian if it is not a Abstract Array -- For eg, JacVec Operator
28+
jacobian!(J, cache) = J
2729
function jacobian!(J::AbstractMatrix{<:Number}, cache)
2830
f = cache.f
2931
uf = cache.uf
@@ -52,14 +54,16 @@ function jacobian!(J::AbstractMatrix{<:Number}, cache)
5254
nothing
5355
end
5456

55-
function build_jac_config(alg, f::F1, uf::F2, du1, u, tmp, du2) where {F1, F2}
57+
function build_jac_and_jac_config(alg, f::F1, uf::F2, du1, u, tmp, du2) where {F1, F2}
5658
haslinsolve = hasfield(typeof(alg), :linsolve)
5759

58-
if !SciMLBase.has_jac(f) && # No Jacobian if has analytical solution
59-
((concrete_jac(alg) === nothing && (!haslinsolve || (haslinsolve && # No Jacobian if linsolve doesn't want it
60-
(alg.linsolve === nothing || LinearSolve.needs_concrete_A(alg.linsolve))))) ||
61-
(concrete_jac(alg) !== nothing && concrete_jac(alg))) # Jacobian if explicitly asked for
62-
jac_prototype = f.jac_prototype
60+
has_analytic_jac = SciMLBase.has_jac(f)
61+
linsolve_needs_jac = (concrete_jac(alg) === nothing &&
62+
(!haslinsolve || (haslinsolve && (alg.linsolve === nothing ||
63+
LinearSolve.needs_concrete_A(alg.linsolve)))))
64+
alg_wants_jac = (concrete_jac(alg) !== nothing && concrete_jac(alg))
65+
66+
if !has_analytic_jac && (linsolve_needs_jac || alg_wants_jac)
6367
sparsity, colorvec = sparsity_colorvec(f, u)
6468

6569
if alg_autodiff(alg)
@@ -70,25 +74,55 @@ function build_jac_config(alg, f::F1, uf::F2, du1, u, tmp, du2) where {F1, F2}
7074
else
7175
typeof(ForwardDiff.Tag(uf, eltype(u)))
7276
end
73-
jac_config = ForwardColorJacCache(uf, u, _chunksize; colorvec = colorvec,
74-
sparsity = sparsity, tag = T)
77+
jac_config = ForwardColorJacCache(uf, u, _chunksize; colorvec, sparsity,
78+
tag = T)
7579
else
7680
if alg_difftype(alg) !== Val{:complex}
77-
jac_config = FiniteDiff.JacobianCache(tmp, du1, du2, alg_difftype(alg),
78-
colorvec = colorvec,
79-
sparsity = sparsity)
81+
jac_config = FiniteDiff.JacobianCache(tmp, du1, du2, alg_difftype(alg);
82+
colorvec, sparsity)
8083
else
8184
jac_config = FiniteDiff.JacobianCache(Complex{eltype(tmp)}.(tmp),
82-
Complex{eltype(du1)}.(du1), nothing,
83-
alg_difftype(alg), eltype(u),
84-
colorvec = colorvec,
85-
sparsity = sparsity)
85+
Complex{eltype(du1)}.(du1), nothing, alg_difftype(alg), eltype(u);
86+
colorvec, sparsity)
8687
end
8788
end
8889
else
8990
jac_config = nothing
9091
end
91-
jac_config
92+
93+
J = if !linsolve_needs_jac
94+
# We don't need to construct the Jacobian
95+
JacVec(uf, u; autodiff = alg_autodiff(alg) ? AutoForwardDiff() : AutoFiniteDiff())
96+
else
97+
if f.jac_prototype === nothing
98+
ArrayInterface.undefmatrix(u)
99+
else
100+
f.jac_prototype
101+
end
102+
end
103+
104+
return J, jac_config
105+
end
106+
107+
# Build Jacobian Caches
108+
function jacobian_caches(alg::Union{NewtonRaphson, LevenbergMarquardt, TrustRegion}, f, u,
109+
p, ::Val{true})
110+
uf = JacobianWrapper(f, p)
111+
112+
du1 = zero(u)
113+
du2 = zero(u)
114+
tmp = zero(u)
115+
J, jac_config = build_jac_and_jac_config(alg, f, uf, du1, u, tmp, du2)
116+
117+
linprob = LinearProblem(J, _vec(zero(u)); u0 = _vec(zero(u)))
118+
weight = similar(u)
119+
recursivefill!(weight, true)
120+
121+
Pl, Pr = wrapprecs(alg.precs(J, nothing, u, p, nothing, nothing, nothing, nothing,
122+
nothing)..., weight)
123+
linsolve = init(linprob, alg.linsolve; alias_A = true, alias_b = true, Pl, Pr)
124+
125+
uf, linsolve, J, du1, jac_config
92126
end
93127

94128
function get_chunksize(jac_config::ForwardDiff.JacobianConfig{

src/levenberg.jl

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -226,27 +226,6 @@ mutable struct LevenbergMarquardtCache{iip, fType, algType, uType, duType, resTy
226226
end
227227
end
228228

229-
function jacobian_caches(alg::LevenbergMarquardt, f, u, p, ::Val{true})
230-
uf = JacobianWrapper(f, p)
231-
J = ArrayInterface.undefmatrix(u)
232-
233-
linprob = LinearProblem(J, _vec(zero(u)); u0 = _vec(zero(u)))
234-
weight = similar(u)
235-
recursivefill!(weight, false)
236-
237-
Pl, Pr = wrapprecs(alg.precs(J, nothing, u, p, nothing, nothing, nothing, nothing,
238-
nothing)..., weight)
239-
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
240-
Pl = Pl, Pr = Pr)
241-
242-
du1 = zero(u)
243-
du2 = zero(u)
244-
tmp = zero(u)
245-
jac_config = build_jac_config(alg, f, uf, du1, u, tmp, du2)
246-
247-
uf, linsolve, J, du1, jac_config
248-
end
249-
250229
function jacobian_caches(alg::LevenbergMarquardt, f, u, p, ::Val{false})
251230
JacobianWrapper(f, p), nothing, ArrayInterface.undefmatrix(u), nothing, nothing
252231
end

src/raphson.jl

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -104,31 +104,6 @@ mutable struct NewtonRaphsonCache{iip, fType, algType, uType, duType, resType, p
104104
end
105105
end
106106

107-
function jacobian_caches(alg::NewtonRaphson, f, u, p, ::Val{true})
108-
uf = JacobianWrapper(f, p)
109-
J = if f.jac_prototype === nothing
110-
ArrayInterface.undefmatrix(u)
111-
else
112-
f.jac_prototype
113-
end
114-
115-
linprob = LinearProblem(J, _vec(zero(u)); u0 = _vec(zero(u)))
116-
weight = similar(u)
117-
recursivefill!(weight, false)
118-
119-
Pl, Pr = wrapprecs(alg.precs(J, nothing, u, p, nothing, nothing, nothing, nothing,
120-
nothing)..., weight)
121-
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
122-
Pl = Pl, Pr = Pr)
123-
124-
du1 = zero(u)
125-
du2 = zero(u)
126-
tmp = zero(u)
127-
jac_config = build_jac_config(alg, f, uf, du1, u, tmp, du2)
128-
129-
uf, linsolve, J, du1, jac_config
130-
end
131-
132107
function jacobian_caches(alg::NewtonRaphson, f, u, p, ::Val{false})
133108
JacobianWrapper(f, p), nothing, nothing, nothing, nothing
134109
end

src/trustRegion.jl

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -278,27 +278,6 @@ mutable struct TrustRegionCache{iip, fType, algType, uType, resType, pType,
278278
end
279279
end
280280

281-
function jacobian_caches(alg::TrustRegion, f, u, p, ::Val{true})
282-
uf = JacobianWrapper(f, p)
283-
J = ArrayInterface.undefmatrix(u)
284-
285-
linprob = LinearProblem(J, _vec(zero(u)); u0 = _vec(zero(u)))
286-
weight = similar(u)
287-
recursivefill!(weight, false)
288-
289-
Pl, Pr = wrapprecs(alg.precs(J, nothing, u, p, nothing, nothing, nothing, nothing,
290-
nothing)..., weight)
291-
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
292-
Pl = Pl, Pr = Pr)
293-
294-
du1 = zero(u)
295-
du2 = zero(u)
296-
tmp = zero(u)
297-
jac_config = build_jac_config(alg, f, uf, du1, u, tmp, du2)
298-
299-
uf, linsolve, J, du1, jac_config
300-
end
301-
302281
function jacobian_caches(alg::TrustRegion, f, u, p, ::Val{false})
303282
J = ArrayInterface.undefmatrix(u)
304283
JacobianWrapper(f, p), nothing, J, zero(u), nothing

src/utils.jl

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,7 @@ function alg_difftype(alg::AbstractNewtonAlgorithm{
5151
FDT,
5252
ST,
5353
CJ,
54-
}) where {CS, AD, FDT,
55-
ST, CJ}
54+
}) where {CS, AD, FDT, ST, CJ}
5655
FDT
5756
end
5857

@@ -62,8 +61,7 @@ function concrete_jac(alg::AbstractNewtonAlgorithm{
6261
FDT,
6362
ST,
6463
CJ,
65-
}) where {CS, AD, FDT,
66-
ST, CJ}
64+
}) where {CS, AD, FDT, ST, CJ}
6765
CJ
6866
end
6967

@@ -73,9 +71,7 @@ function get_chunksize(alg::AbstractNewtonAlgorithm{
7371
FDT,
7472
ST,
7573
CJ,
76-
}) where {CS, AD,
77-
FDT,
78-
ST, CJ}
74+
}) where {CS, AD, FDT, ST, CJ}
7975
Val(CS)
8076
end
8177

@@ -85,17 +81,15 @@ function standardtag(alg::AbstractNewtonAlgorithm{
8581
FDT,
8682
ST,
8783
CJ,
88-
}) where {CS, AD, FDT,
89-
ST, CJ}
84+
}) where {CS, AD, FDT, ST, CJ}
9085
ST
9186
end
9287

9388
DEFAULT_PRECS(W, du, u, p, t, newW, Plprev, Prprev, cachedata) = nothing, nothing
9489

9590
function dolinsolve(precs::P, linsolve; A = nothing, linu = nothing, b = nothing,
96-
du = nothing, u = nothing, p = nothing, t = nothing,
97-
weight = nothing, cachedata = nothing,
98-
reltol = nothing) where {P}
91+
du = nothing, u = nothing, p = nothing, t = nothing, weight = nothing,
92+
cachedata = nothing, reltol = nothing) where {P}
9993
A !== nothing && (linsolve.A = A)
10094
b !== nothing && (linsolve.b = b)
10195
linu !== nothing && (linsolve.u = linu)

0 commit comments

Comments
 (0)