Skip to content

Commit e87a82d

Browse files
committed
Patch broken solvers + better testing
1 parent 4fe75e8 commit e87a82d

File tree

10 files changed

+1476
-2070
lines changed

10 files changed

+1476
-2070
lines changed

Project.toml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
2727
ArrayInterface = "6.0.24, 7"
2828
DiffEqBase = "6"
2929
EnumX = "1"
30+
Enzyme = "0.11"
3031
FiniteDiff = "2"
3132
ForwardDiff = "0.10.3"
3233
LinearSolve = "2"
@@ -38,19 +39,23 @@ SimpleNonlinearSolve = "0.1"
3839
SparseDiffTools = "1, 2"
3940
StaticArraysCore = "1.4"
4041
UnPack = "1.0"
42+
Zygote = "0.6"
4143
julia = "1.6"
4244

4345
[extras]
4446
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
47+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
4548
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
4649
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
4750
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
4851
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
4952
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
5053
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
54+
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
5155
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
5256
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
5357
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
58+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
5459

5560
[targets]
56-
test = ["BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra"]
61+
test = ["Enzyme", "BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra", "Zygote", "SparseDiffTools"]

src/NonlinearSolve.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,16 @@ using DiffEqBase, LinearAlgebra, LinearSolve, SparseDiffTools
88
import ForwardDiff
99

1010
import ADTypes: AbstractFiniteDifferencesMode
11-
import ArrayInterface: undefmatrix, matrix_colors
11+
import ArrayInterface: undefmatrix, matrix_colors, parameterless_type, ismutable
1212
import ConcreteStructs: @concrete
1313
import EnumX: @enumx
1414
import ForwardDiff: Dual
1515
import LinearSolve: ComposePreconditioner, InvPreconditioner, needs_concrete_A
16-
import RecursiveArrayTools: AbstractVectorOfArray, recursivecopy!, recursivefill!
16+
import RecursiveArrayTools: ArrayPartition,
17+
AbstractVectorOfArray, recursivecopy!, recursivefill!
1718
import Reexport: @reexport
1819
import SciMLBase: AbstractNonlinearAlgorithm, NLStats, _unwrap_val, has_jac, isinplace
19-
import StaticArraysCore: StaticArray, SVector
20+
import StaticArraysCore: StaticArray, SVector, SArray, MArray
2021
import UnPack: @unpack
2122

2223
@reexport using ADTypes, SciMLBase, SimpleNonlinearSolve
@@ -33,8 +34,6 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::AbstractNonlinearSolveAl
3334
return solve!(cache)
3435
end
3536

36-
# FIXME: Scalar Case is Completely Broken
37-
3837
include("utils.jl")
3938
include("raphson.jl")
4039
include("trustRegion.jl")

src/jacobian.jl

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,16 @@ function jacobian!!(J::Union{AbstractMatrix{<:Number}, Nothing}, cache)
3636
@unpack f, uf, u, p, jac_cache, alg, fu2 = cache
3737
iip = isinplace(cache)
3838
if iip
39-
has_jac(f) ? f.jac(J, u, p) : sparse_jacobian!(J, alg.ad, jac_cache, uf, fu2, u)
39+
has_jac(f) ? f.jac(J, u, p) :
40+
sparse_jacobian!(J, alg.ad, jac_cache, uf, fu2, _maybe_mutable(u, alg.ad))
4041
else
41-
return has_jac(f) ? f.jac(u, p) : sparse_jacobian!(J, alg.ad, jac_cache, uf, u)
42+
return has_jac(f) ? f.jac(u, p) :
43+
sparse_jacobian!(J, alg.ad, jac_cache, uf, _maybe_mutable(u, alg.ad))
4244
end
43-
return nothing
45+
return J
4446
end
47+
# Scalar case
48+
jacobian!!(::Number, cache) = last(value_derivative(cache.uf, cache.u))
4549

4650
# Build Jacobian Caches
4751
function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f, u, p,
@@ -54,15 +58,16 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f, u, p,
5458
linsolve_needs_jac = (concrete_jac(alg) === nothing &&
5559
(!haslinsolve || (haslinsolve && (alg.linsolve === nothing ||
5660
needs_concrete_A(alg.linsolve)))))
57-
alg_wants_jac = (concrete_jac(alg) === nothing && concrete_jac(alg))
61+
alg_wants_jac = (concrete_jac(alg) !== nothing && concrete_jac(alg))
5862

5963
# NOTE: The deepcopy is needed here since we are using the resid_prototype elsewhere
60-
fu = f.resid_prototype === nothing ? (iip ? zero(u) : f(u, p)) :
61-
deepcopy(f.resid_prototype)
64+
fu = f.resid_prototype === nothing ? (iip ? _mutable_zero(u) : _mutable(f(u, p))) :
65+
(iip ? deepcopy(f.resid_prototype) : f.resid_prototype)
6266
if !has_analytic_jac && (linsolve_needs_jac || alg_wants_jac)
6367
sd = sparsity_detection_alg(f, alg.ad)
64-
jac_cache = iip ? sparse_jacobian_cache(alg.ad, sd, uf, fu, u) :
65-
sparse_jacobian_cache(alg.ad, sd, uf, u; fx = fu)
68+
ad = alg.ad
69+
jac_cache = iip ? sparse_jacobian_cache(ad, sd, uf, fu, _maybe_mutable(u, ad)) :
70+
sparse_jacobian_cache(ad, sd, uf, _maybe_mutable(u, ad); fx = fu)
6671
else
6772
jac_cache = nothing
6873
end
@@ -78,7 +83,7 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f, u, p,
7883
end
7984
end
8085

81-
du = zero(u)
86+
du = _mutable_zero(u)
8287
linprob = LinearProblem(J, _vec(fu); u0 = _vec(du))
8388

8489
weight = similar(u)
@@ -90,3 +95,11 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f, u, p,
9095

9196
return uf, linsolve, J, fu, jac_cache, du
9297
end
98+
99+
## Special Handling for Scalars
100+
function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f, u::Number, p,
101+
::Val{false})
102+
# NOTE: Scalar `u` assumes scalar output from `f`
103+
uf = JacobianWrapper(f, p)
104+
return uf, nothing, u, nothing, nothing, u
105+
end

0 commit comments

Comments
 (0)