Skip to content

Commit 6cdff35

Browse files
committed
feat: add partial SimpleKlement Implementation
1 parent 24e68e9 commit 6cdff35

File tree

5 files changed

+181
-3
lines changed

5 files changed

+181
-3
lines changed

lib/SimpleNonlinearSolve/Project.toml

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,16 @@ ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
99
BracketingNonlinearSolve = "70df07ce-3d50-431d-a3e7-ca6ddb60ac1e"
1010
CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
1111
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
12+
FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
1213
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
1314
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
15+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
16+
MaybeInplace = "bb5d69b7-63fc-4a16-80bd-7e42200c7bdb"
1417
NonlinearSolveBase = "be0214bd-f91f-a760-ac4e-3421ce2b2da0"
1518
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
1619
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1720
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
21+
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
1822

1923
[weakdeps]
2024
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -36,12 +40,24 @@ ChainRulesCore = "1.24"
3640
CommonSolve = "0.2.4"
3741
DiffEqBase = "6.155"
3842
DifferentiationInterface = "0.5.17"
43+
FastClosures = "0.3.2"
3944
FiniteDiff = "2.24.0"
4045
ForwardDiff = "0.10.36"
46+
LinearAlgebra = "1.10"
47+
MaybeInplace = "0.1.4"
4148
NonlinearSolveBase = "1"
4249
PrecompileTools = "1.2"
4350
Reexport = "1.2"
4451
ReverseDiff = "1.15"
4552
SciMLBase = "2.50"
53+
StaticArraysCore = "1.4.3"
4654
Tracker = "0.2.35"
4755
julia = "1.10"
56+
57+
[extras]
58+
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
59+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
60+
TestItemRunner = "f8b46487-2199-4994-9208-9a1283c18c0a"
61+
62+
[targets]
63+
test = ["InteractiveUtils", "Test", "TestItemRunner"]

lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,33 @@
11
module SimpleNonlinearSolve
22

33
using CommonSolve: CommonSolve, solve
4+
using FastClosures: @closure
5+
using MaybeInplace: @bb
46
using PrecompileTools: @compile_workload, @setup_workload
57
using Reexport: @reexport
68
@reexport using SciMLBase # I don't like this but needed to avoid a breaking change
79
using SciMLBase: AbstractNonlinearAlgorithm, NonlinearProblem, ReturnCode
10+
using StaticArraysCore: StaticArray
811

912
# AD Dependencies
10-
using ADTypes: ADTypes, AbstractADType, AutoFiniteDiff, AutoForwardDiff,
11-
AutoPolyesterForwardDiff
13+
using ADTypes: AbstractADType, AutoFiniteDiff, AutoForwardDiff, AutoPolyesterForwardDiff
1214
using DifferentiationInterface: DifferentiationInterface
1315
# TODO: move these to extensions in a breaking change. These are not even used in the
1416
# package, but are used to trigger the extension loading in DI.jl
1517
using FiniteDiff: FiniteDiff
1618
using ForwardDiff: ForwardDiff
1719

1820
using BracketingNonlinearSolve: Alefeld, Bisection, Brent, Falsi, ITP, Ridder
19-
using NonlinearSolveBase: ImmutableNonlinearProblem
21+
using NonlinearSolveBase: ImmutableNonlinearProblem, get_tolerance
2022

2123
const DI = DifferentiationInterface
2224

2325
abstract type AbstractSimpleNonlinearSolveAlgorithm <: AbstractNonlinearAlgorithm end
2426

2527
is_extension_loaded(::Val) = false
2628

29+
include("utils.jl")
30+
2731
# By Pass the highlevel checks for NonlinearProblem for Simple Algorithms
2832
function CommonSolve.solve(prob::NonlinearProblem,
2933
alg::AbstractSimpleNonlinearSolveAlgorithm, args...; kwargs...)
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
"""
2+
SimpleKlement()
3+
4+
A low-overhead implementation of `Klement` [klement2014using](@citep). This
5+
method is non-allocating on scalar and static array problems.
6+
"""
7+
struct SimpleKlement <: AbstractSimpleNonlinearSolveAlgorithm end
8+
9+
function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleKlement, args...;
10+
abstol = nothing, reltol = nothing, maxiters = 1000,
11+
alias_u0 = false, termination_condition = nothing, kwargs...)
12+
x = Utils.maybe_unaliased(prob.u0, alias_u0)
13+
T = eltype(x)
14+
15+
abstol, reltol, tc_cache = NonlinearSolveBase.init_termination_cache(
16+
prob, abstol, reltol, fx, x, termination_condition, Val(:simple))
17+
18+
@bb δx = copy(x)
19+
@bb fprev = copy(fx)
20+
@bb xo = copy(x)
21+
@bb d = copy(x)
22+
23+
J = one.(x)
24+
@bb δx² = similar(x)
25+
26+
for _ in 1:maxiters
27+
any(iszero, J) && (J = Utils.identity_jacobian!!(J))
28+
29+
@bb @. δx = fprev / J
30+
31+
@bb @. x = xo - δx
32+
fx = Utils.eval_f(prob, fx, x)
33+
34+
# Termination Checks
35+
# tc_sol = check_termination(tc_cache, fx, x, xo, prob, alg)
36+
tc_sol !== nothing && return tc_sol
37+
38+
@bb δx .*= -1
39+
@bb @. δx² = δx^2 * J^2
40+
@bb @. J += (fx - fprev - J * δx) / ifelse(iszero(δx²), T(1e-5), δx²) * δx * (J^2)
41+
42+
@bb copyto!(fprev, fx)
43+
@bb copyto!(xo, x)
44+
end
45+
46+
return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.MaxIters)
47+
end

lib/SimpleNonlinearSolve/src/utils.jl

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
module Utils
2+
3+
using ADTypes: AbstractADType, AutoForwardDiff, AutoFiniteDiff, AutoPolyesterForwardDiff
4+
using ArrayInterface: ArrayInterface
5+
using DifferentiationInterface: DifferentiationInterface
6+
using FastClosures: @closure
7+
using LinearAlgebra: LinearAlgebra, I, diagind
8+
using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem
9+
using SciMLBase: SciMLBase, NonlinearLeastSquaresProblem, NonlinearProblem,
10+
NonlinearFunction
11+
using StaticArraysCore: StaticArray, SArray, SMatrix, SVector
12+
13+
const DI = DifferentiationInterface
14+
15+
const safe_similar = NonlinearSolveBase.Utils.safe_similar
16+
17+
pickchunksize(n::Int) = min(n, 12)
18+
19+
can_dual(::Type{<:Real}) = true
20+
can_dual(::Type) = false
21+
22+
maybe_unaliased(x::Union{Number, SArray}, ::Bool) = x
23+
function maybe_unaliased(x::T, alias::Bool) where {T <: AbstractArray}
24+
(alias || !ArrayInterface.can_setindex(T)) && return x
25+
return copy(x)
26+
end
27+
28+
function get_concrete_autodiff(_, ad::AbstractADType)
29+
DI.check_available(ad) && return ad
30+
error("AD Backend $(ad) is not available. This could be because you haven't loaded the \
31+
actual backend (See [Differentiation Inferface Docs](https://gdalle.github.io/DifferentiationInterface.jl/DifferentiationInterface/stable/) \
32+
for more details) or the backend might not be supported by DifferentiationInferface.jl.")
33+
end
34+
function get_concrete_autodiff(
35+
prob, ad::Union{AutoForwardDiff{nothing}, AutoPolyesterForwardDiff{nothing}})
36+
return get_concrete_autodiff(prob,
37+
ArrayInterface.parameterless_type(ad)(;
38+
chunksize = pickchunksize(length(prob.u0)), ad.tag))
39+
end
40+
function get_concrete_autodiff(prob, ::Nothing)
41+
if can_dual(eltype(prob.u0)) && DI.check_available(AutoForwardDiff())
42+
return AutoForwardDiff(; chunksize = pickchunksize(length(prob.u0)))
43+
end
44+
DI.check_available(AutoFiniteDiff()) && return AutoFiniteDiff()
45+
error("Default AD backends are not available. Please load either FiniteDiff or \
46+
ForwardDiff for default AD selection to work. Else provide a specific AD \
47+
backend (instead of `nothing`) to the solver.")
48+
end
49+
50+
# NOTE: This doesn't initialize the `f(x)` but just returns a buffer of the same size
51+
function get_fx(prob::NonlinearLeastSquaresProblem, x)
52+
if SciMLBase.isinplace(prob) && prob.f.resid_prototype === nothing
53+
error("Inplace NonlinearLeastSquaresProblem requires a `resid_prototype` to be \
54+
specified.")
55+
end
56+
return get_fx(prob.f, x, prob.p)
57+
end
58+
function get_fx(prob::Union{ImmutableNonlinearProblem, NonlinearProblem}, x)
59+
return get_fx(prob.f, x, prob.p)
60+
end
61+
function get_fx(f::NonlinearFunction, x, p)
62+
if SciMLBase.isinplace(f)
63+
f.resid_prototype === nothing && return eltype(x).(f.resid_prototype)
64+
return safe_similar(x)
65+
end
66+
return f(x, p)
67+
end
68+
69+
function eval_f(prob, fx, x)
70+
SciMLBase.isinplace(prob) || return prob.f(x, prob.p)
71+
prob.f(fx, x, prob.p)
72+
return fx
73+
end
74+
75+
function fixed_parameter_function(prob::AbstractNonlinearProblem)
76+
SciMLBase.isinplace(prob) && return @closure (du, u) -> prob.f(du, u, prob.p)
77+
return Base.Fix2(prob.f, prob.p)
78+
end
79+
80+
# __init_identity_jacobian(u::Number, fu, α = true) = oftype(u, α)
81+
# function __init_identity_jacobian(u, fu, α = true)
82+
# J = __similar(u, promote_type(eltype(u), eltype(fu)), length(fu), length(u))
83+
# fill!(J, zero(eltype(J)))
84+
# J[diagind(J)] .= eltype(J)(α)
85+
# return J
86+
# end
87+
# function __init_identity_jacobian(u::StaticArray, fu, α = true)
88+
# S1, S2 = length(fu), length(u)
89+
# J = SMatrix{S1, S2, eltype(u)}(I * α)
90+
# return J
91+
# end
92+
93+
identity_jacobian!!(J::Number) = one(J)
94+
function identity_jacobian!!(J::AbstractVector)
95+
ArrayInterface.can_setindex(J) || return one.(J)
96+
fill!(J, true)
97+
return J
98+
end
99+
function identity_jacobian!!(J::AbstractMatrix)
100+
ArrayInterface.can_setindex(J) || return convert(typeof(J), I)
101+
J[diagind(J)] .= true
102+
return J
103+
end
104+
identity_jacobian!!(::SMatrix{S1, S2, T}) where {S1, S2, T} = SMatrix{S1, S2, T}(I)
105+
identity_jacobian!!(::SVector{S1, T}) where {S1, T} = ones(SVector{S1, T})
106+
107+
end
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,5 @@
1+
using TestItemRunner, InteractiveUtils
12

3+
@info sprint(InteractiveUtils.versioninfo)
4+
5+
@run_package_tests

0 commit comments

Comments
 (0)