Skip to content

Commit 1364deb

Browse files
tests passing
1 parent c17ac30 commit 1364deb

File tree

3 files changed

+60
-34
lines changed

3 files changed

+60
-34
lines changed

src/jacobian.jl

Lines changed: 30 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,12 @@
1-
mutable struct JacobianWrapper{fType, pType}
1+
struct JacobianWrapper{fType, pType}
22
f::fType
33
p::pType
44
end
55

66
(uf::JacobianWrapper)(u) = uf.f(u, uf.p)
77
(uf::JacobianWrapper)(res, u) = uf.f(res, u, uf.p)
88

9-
struct ImmutableJacobianWrapper{fType, pType}
10-
f::fType
11-
p::pType
12-
end
13-
14-
(uf::ImmutableJacobianWrapper)(u) = uf.f(u, uf.p)
9+
struct NonlinearSolveTag end
1510

1611
function sparsity_colorvec(f, x)
1712
sparsity = f.sparsity
@@ -21,33 +16,35 @@ function sparsity_colorvec(f, x)
2116
end
2217

2318
function jacobian_finitediff_forward!(J, f, x, jac_config, forwardcache, cache)
24-
(FiniteDiff.finite_difference_jacobian!(J, f, x, jac_config, forwardcache,
25-
dir = diffdir(cache));
19+
(FiniteDiff.finite_difference_jacobian!(J, f, x, jac_config, forwardcache);
2620
maximum(jac_config.colorvec))
2721
end
2822
function jacobian_finitediff!(J, f, x, jac_config, cache)
29-
(FiniteDiff.finite_difference_jacobian!(J, f, x, jac_config,
30-
dir = diffdir(cache));
23+
(FiniteDiff.finite_difference_jacobian!(J, f, x, jac_config);
3124
2 * maximum(jac_config.colorvec))
3225
end
3326

34-
function jacobian!(J::AbstractMatrix{<:Number}, f, x::AbstractArray{<:Number},
35-
fx::AbstractArray{<:Number}, cache,
36-
jac_config)
37-
alg = unwrap_alg(cache, true)
27+
function jacobian!(J::AbstractMatrix{<:Number}, cache)
28+
f = cache.f
29+
uf = cache.uf
30+
x = cache.u
31+
fx = cache.fu
32+
jac_config = cache.jac_config
33+
alg = cache.alg
34+
3835
if alg_autodiff(alg)
39-
forwarddiff_color_jacobian!(J, f, x, jac_config)
36+
forwarddiff_color_jacobian!(J, uf, x, jac_config)
4037
#cache.destats.nf += 1
4138
else
4239
isforward = alg_difftype(alg) === Val{:forward}
4340
if isforward
4441
forwardcache = get_tmp_cache(cache, alg, unwrap_cache(cache, true))[2]
45-
f(forwardcache, x)
42+
uf(forwardcache, x)
4643
#cache.destats.nf += 1
47-
tmp = jacobian_finitediff_forward!(J, f, x, jac_config, forwardcache,
44+
tmp = jacobian_finitediff_forward!(J, uf, x, jac_config, forwardcache,
4845
cache)
4946
else # not forward difference
50-
tmp = jacobian_finitediff!(J, f, x, jac_config, cache)
47+
tmp = jacobian_finitediff!(J, uf, x, jac_config, cache)
5148
end
5249
cache.destats.nf += tmp
5350
end
@@ -57,7 +54,7 @@ end
5754
function build_jac_config(alg, f::F1, uf::F2, du1, u, tmp, du2) where {F1, F2}
5855
haslinsolve = hasfield(typeof(alg), :linsolve)
5956

60-
if SciMLBase.has_jac(f) && # No Jacobian if has analytical solution
57+
if !SciMLBase.has_jac(f) && # No Jacobian if has analytical solution
6158
((concrete_jac(alg) === nothing && (!haslinsolve || (haslinsolve && # No Jacobian if linsolve doesn't want it
6259
(alg.linsolve === nothing || LinearSolve.needs_concrete_A(alg.linsolve))))) ||
6360
(concrete_jac(alg) !== nothing && concrete_jac(alg))) # Jacobian if explicitly asked for
@@ -68,7 +65,7 @@ function build_jac_config(alg, f::F1, uf::F2, du1, u, tmp, du2) where {F1, F2}
6865
_chunksize = get_chunksize(alg) === Val(0) ? nothing : get_chunksize(alg) # SparseDiffEq uses different convection...
6966

7067
T = if standardtag(alg)
71-
typeof(ForwardDiff.Tag(OrdinaryDiffEqTag(), eltype(u)))
68+
typeof(ForwardDiff.Tag(NonlinearSolveTag(), eltype(u)))
7269
else
7370
typeof(ForwardDiff.Tag(uf, eltype(u)))
7471
end
@@ -112,19 +109,23 @@ function jacobian_finitediff(f, x::AbstractArray, ::Type{diff_type}, dir, colorv
112109
jac_prototype = jac_prototype)
113110
return J, _nfcount(maximum(colorvec), diff_type)
114111
end
115-
function jacobian(f, x, cache)
116-
alg = unwrap_alg(cache, true)
112+
function jacobian(cache, f::F) where F
113+
x = cache.u
114+
alg = cache.alg
115+
uf = cache.uf
117116
local tmp
118-
if alg_autodiff(alg)
119-
J, tmp = jacobian_autodiff(f, x, cache.f, alg)
117+
118+
if DiffEqBase.has_jac(cache.f)
119+
J = f.jac(cache.u, cache.p)
120+
elseif alg_autodiff(alg)
121+
J, tmp = jacobian_autodiff(uf, x, cache.f, alg)
120122
else
121123
jac_prototype = cache.f.jac_prototype
122124
sparsity, colorvec = sparsity_colorvec(cache.f, x)
123-
dir = diffdir(cache)
124-
J, tmp = jacobian_finitediff(f, x, alg_difftype(alg), dir, colorvec, sparsity,
125+
dir = true
126+
J, tmp = jacobian_finitediff(uf, x, alg_difftype(alg), dir, colorvec, sparsity,
125127
jac_prototype)
126128
end
127-
cache.destats.nf += tmp
128129
J
129130
end
130131

@@ -135,7 +136,7 @@ function jacobian_autodiff(f, x::AbstractArray, nonlinfun, alg)
135136
maxcolor = maximum(colorvec)
136137
chunk_size = get_chunksize(alg) === Val(0) ? nothing : get_chunksize(alg)
137138
num_of_chunks = chunk_size === nothing ?
138-
Int(ceil(maxcolor / getsize(ForwardDiff.pickchunksize(maxcolor)))) :
139+
Int(ceil(maxcolor / SparseDiffTools.getsize(ForwardDiff.pickchunksize(maxcolor)))) :
139140
Int(ceil(maxcolor / _unwrap_val(chunk_size)))
140141
(forwarddiff_color_jacobian(f, x, colorvec = colorvec, sparsity = sparsity,
141142
jac_prototype = jac_prototype, chunksize = chunk_size),

src/raphson.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ function jacobian_caches(alg::NewtonRaphson, f, u, p, ::Val{true})
7373
end
7474

7575
function jacobian_caches(alg::NewtonRaphson, f, u, p, ::Val{false})
76-
nothing, nothing, nothing, nothing, nothing
76+
JacobianWrapper(f,p), nothing, nothing, nothing, nothing
7777
end
7878

7979
function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::NewtonRaphson,
@@ -106,7 +106,7 @@ end
106106
function perform_step!(cache::NewtonRaphsonCache{true})
107107
@unpack u, fu, f, p, alg = cache
108108
@unpack J, linsolve, du1 = cache
109-
calc_J!(J, cache, cache)
109+
jacobian!(J, cache)
110110

111111
# u = u - J \ fu
112112
linres = dolinsolve(alg.precs, linsolve, A = J, b = fu, linu = du1,
@@ -123,10 +123,9 @@ end
123123

124124
function perform_step!(cache::NewtonRaphsonCache{false})
125125
@unpack u, fu, f, p = cache
126-
J = calc_J(cache, ImmutableJacobianWrapper(f, p))
126+
J = jacobian(cache, f)
127127
cache.u = u - J \ fu
128-
fu = f(cache.u, p)
129-
cache.fu = fu
128+
cache.fu = f(cache.u, p)
130129
if iszero(cache.fu) || cache.internalnorm(cache.fu) < cache.abstol
131130
cache.force_stop = true
132131
end

src/utils.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,21 @@ function alg_difftype(alg::AbstractNewtonAlgorithm{CS, AD, FDT, ST, CJ}) where {
4848
FDT
4949
end
5050

51+
function concrete_jac(alg::AbstractNewtonAlgorithm{CS, AD, FDT, ST, CJ}) where {CS, AD, FDT,
52+
ST, CJ}
53+
CJ
54+
end
55+
56+
function get_chunksize(alg::AbstractNewtonAlgorithm{CS, AD, FDT, ST, CJ}) where {CS, AD, FDT,
57+
ST, CJ}
58+
Val(CS)
59+
end
60+
61+
function standardtag(alg::AbstractNewtonAlgorithm{CS, AD, FDT, ST, CJ}) where {CS, AD, FDT,
62+
ST, CJ}
63+
ST
64+
end
65+
5166
DEFAULT_PRECS(W, du, u, p, t, newW, Plprev, Prprev, cachedata) = nothing, nothing
5267

5368
function dolinsolve(precs::P, linsolve; A = nothing, linu = nothing, b = nothing,
@@ -97,3 +112,14 @@ function wrapprecs(_Pl, _Pr, weight)
97112
end
98113
Pl, Pr
99114
end
115+
116+
function _nfcount(N, ::Type{diff_type}) where {diff_type}
117+
if diff_type === Val{:complex}
118+
tmp = N
119+
elseif diff_type === Val{:forward}
120+
tmp = N + 1
121+
else
122+
tmp = 2N
123+
end
124+
tmp
125+
end

0 commit comments

Comments
 (0)