Skip to content

Commit 4fe75e8

Browse files
committed
Incorporate upstream changes in NonlinearSolve.jl
1 parent 8869f80 commit 4fe75e8

File tree

6 files changed

+346
-332
lines changed

6 files changed

+346
-332
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,4 @@ Manifest.toml
2525
docs/src/assets/Project.toml
2626

2727
.vscode
28+
wip

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ LinearSolve = "2"
3333
PrecompileTools = "1"
3434
RecursiveArrayTools = "2"
3535
Reexport = "0.2, 1"
36-
SciMLBase = "1.92.4"
36+
SciMLBase = "1.97"
3737
SimpleNonlinearSolve = "0.1"
3838
SparseDiffTools = "1, 2"
3939
StaticArraysCore = "1.4"

src/NonlinearSolve.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,14 @@ using DiffEqBase, LinearAlgebra, LinearSolve, SparseDiffTools
88
import ForwardDiff
99

1010
import ADTypes: AbstractFiniteDifferencesMode
11-
import ArrayInterface: undefmatrix
11+
import ArrayInterface: undefmatrix, matrix_colors
1212
import ConcreteStructs: @concrete
1313
import EnumX: @enumx
1414
import ForwardDiff: Dual
1515
import LinearSolve: ComposePreconditioner, InvPreconditioner, needs_concrete_A
1616
import RecursiveArrayTools: AbstractVectorOfArray, recursivecopy!, recursivefill!
1717
import Reexport: @reexport
1818
import SciMLBase: AbstractNonlinearAlgorithm, NLStats, _unwrap_val, has_jac, isinplace
19-
import SparseDiffTools: __init_𝒥
2019
import StaticArraysCore: StaticArray, SVector
2120
import UnPack: @unpack
2221

src/jacobian.jl

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,27 @@ end
66
(uf::JacobianWrapper)(u) = uf.f(u, uf.p)
77
(uf::JacobianWrapper)(res, u) = uf.f(res, u, uf.p)
88

9-
# function sparsity_colorvec(f, x)
10-
# sparsity = f.sparsity
11-
# colorvec = DiffEqBase.has_colorvec(f) ? f.colorvec :
12-
# (isnothing(sparsity) ? (1:length(x)) : matrix_colors(sparsity))
13-
# sparsity, colorvec
14-
# end
9+
# FIXME: This is a deviation from older versions. Previously if sparsity and colorvec were
10+
# provided we would use a sparse AD. Right now it requires an explicit specification
11+
sparsity_detection_alg(f, ad) = NoSparsityDetection()
12+
function sparsity_detection_alg(f, ad::AbstractSparseADType)
13+
if f.sparsity === nothing
14+
if f.jac_prototype === nothing
15+
return SymbolicsSparsityDetection()
16+
else
17+
jac_prototype = f.jac_prototype
18+
end
19+
else
20+
jac_prototype = f.sparsity
21+
end
22+
23+
if SciMLBase.has_colorvec(f)
24+
return PrecomputedJacobianColorvec(; jac_prototype, f.colorvec,
25+
partition_by_rows = ad isa ADTypes.AbstractSparseReverseMode)
26+
else
27+
return JacPrototypeSparsityDetection(; jac_prototype)
28+
end
29+
end
1530

1631
# NoOp for Jacobian if it is not a Abstract Array -- For eg, JacVec Operator
1732
jacobian!!(J, _) = J
@@ -41,14 +56,13 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f, u, p,
4156
needs_concrete_A(alg.linsolve)))))
4257
alg_wants_jac = (concrete_jac(alg) === nothing && concrete_jac(alg))
4358

44-
fu = zero(u) # TODO: Use Prototype
59+
# NOTE: The deepcopy is needed here since we are using the resid_prototype elsewhere
60+
fu = f.resid_prototype === nothing ? (iip ? zero(u) : f(u, p)) :
61+
deepcopy(f.resid_prototype)
4562
if !has_analytic_jac && (linsolve_needs_jac || alg_wants_jac)
46-
# TODO: We need an Upstream Mode to allow using known sparsity and colorvec
47-
# TODO: We can use the jacobian prototype here
48-
sd = typeof(alg.ad) <: AbstractSparseADType ? SymbolicsSparsityDetection() :
49-
NoSparsityDetection()
63+
sd = sparsity_detection_alg(f, alg.ad)
5064
jac_cache = iip ? sparse_jacobian_cache(alg.ad, sd, uf, fu, u) :
51-
sparse_jacobian_cache(alg.ad, sd, uf, u; fx=fu)
65+
sparse_jacobian_cache(alg.ad, sd, uf, u; fx = fu)
5266
else
5367
jac_cache = nothing
5468
end
@@ -60,12 +74,12 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f, u, p,
6074
if has_analytic_jac
6175
iip ? undefmatrix(u) : nothing
6276
else
63-
f.jac_prototype === nothing ? __init_𝒥(jac_cache) : f.jac_prototype
77+
f.jac_prototype === nothing ? init_jacobian(jac_cache) : f.jac_prototype
6478
end
6579
end
6680

67-
# FIXME: Assumes same sized `u` and `fu` -- Incorrect Assumption for Levenberg
68-
linprob = LinearProblem(J, _vec(zero(u)); u0 = _vec(zero(u)))
81+
du = zero(u)
82+
linprob = LinearProblem(J, _vec(fu); u0 = _vec(du))
6983

7084
weight = similar(u)
7185
recursivefill!(weight, true)
@@ -74,5 +88,5 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f, u, p,
7488
nothing)..., weight)
7589
linsolve = init(linprob, alg.linsolve; alias_A = true, alias_b = true, Pl, Pr)
7690

77-
return uf, linsolve, J, fu, jac_cache
91+
return uf, linsolve, J, fu, jac_cache, du
7892
end

0 commit comments

Comments
 (0)