Skip to content

Commit 445e97b

Browse files
committed
Trust Region mostly works
1 parent 13e590e commit 445e97b

File tree

5 files changed

+245
-344
lines changed

5 files changed

+245
-344
lines changed

src/NonlinearSolve.jl

Lines changed: 49 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ include("trace.jl")
169169
include("extension_algs.jl")
170170
include("linesearch.jl")
171171
include("raphson.jl")
172-
# include("trustRegion.jl")
172+
include("trustRegion.jl")
173173
include("levenberg.jl")
174174
include("gaussnewton.jl")
175175
include("dfsane.jl")
@@ -179,54 +179,54 @@ include("klement.jl")
179179
include("lbroyden.jl")
180180
include("jacobian.jl")
181181
include("ad.jl")
182-
# include("default.jl")
183-
184-
# @setup_workload begin
185-
# nlfuncs = ((NonlinearFunction{false}((u, p) -> u .* u .- p), 0.1),
186-
# (NonlinearFunction{false}((u, p) -> u .* u .- p), [0.1]),
187-
# (NonlinearFunction{true}((du, u, p) -> du .= u .* u .- p), [0.1]))
188-
# probs_nls = NonlinearProblem[]
189-
# for T in (Float32, Float64), (fn, u0) in nlfuncs
190-
# push!(probs_nls, NonlinearProblem(fn, T.(u0), T(2)))
191-
# end
192-
193-
# nls_algs = (NewtonRaphson(), TrustRegion(), LevenbergMarquardt(), PseudoTransient(),
194-
# GeneralBroyden(), GeneralKlement(), DFSane(), nothing)
195-
196-
# probs_nlls = NonlinearLeastSquaresProblem[]
197-
# nlfuncs = ((NonlinearFunction{false}((u, p) -> (u .^ 2 .- p)[1:1]), [0.1, 0.0]),
198-
# (NonlinearFunction{false}((u, p) -> vcat(u .* u .- p, u .* u .- p)), [0.1, 0.1]),
199-
# (NonlinearFunction{true}((du, u, p) -> du[1] = u[1] * u[1] - p,
200-
# resid_prototype = zeros(1)), [0.1, 0.0]),
201-
# (NonlinearFunction{true}((du, u, p) -> du .= vcat(u .* u .- p, u .* u .- p),
202-
# resid_prototype = zeros(4)), [0.1, 0.1]))
203-
# for (fn, u0) in nlfuncs
204-
# push!(probs_nlls, NonlinearLeastSquaresProblem(fn, u0, 2.0))
205-
# end
206-
# nlfuncs = ((NonlinearFunction{false}((u, p) -> (u .^ 2 .- p)[1:1]), Float32[0.1, 0.0]),
207-
# (NonlinearFunction{false}((u, p) -> vcat(u .* u .- p, u .* u .- p)),
208-
# Float32[0.1, 0.1]),
209-
# (NonlinearFunction{true}((du, u, p) -> du[1] = u[1] * u[1] - p,
210-
# resid_prototype = zeros(Float32, 1)), Float32[0.1, 0.0]),
211-
# (NonlinearFunction{true}((du, u, p) -> du .= vcat(u .* u .- p, u .* u .- p),
212-
# resid_prototype = zeros(Float32, 4)), Float32[0.1, 0.1]))
213-
# for (fn, u0) in nlfuncs
214-
# push!(probs_nlls, NonlinearLeastSquaresProblem(fn, u0, 2.0f0))
215-
# end
216-
217-
# nlls_algs = (LevenbergMarquardt(), GaussNewton(),
218-
# LevenbergMarquardt(; linsolve = LUFactorization()),
219-
# GaussNewton(; linsolve = LUFactorization()))
220-
221-
# @compile_workload begin
222-
# for prob in probs_nls, alg in nls_algs
223-
# solve(prob, alg, abstol = 1e-2)
224-
# end
225-
# for prob in probs_nlls, alg in nlls_algs
226-
# solve(prob, alg, abstol = 1e-2)
227-
# end
228-
# end
229-
# end
182+
include("default.jl")
183+
184+
# Precompilation workload: build a small set of nonlinear and nonlinear least-squares
# problems, then run `solve` on every (problem, algorithm) pair inside
# `@compile_workload`.
@setup_workload begin
    # Root-finding residual `u .* u .- p` with a scalar `u0`, an out-of-place vector
    # `u0`, and an in-place vector `u0`.
    nls_cases = ((NonlinearFunction{false}((u, p) -> u .* u .- p), 0.1),
        (NonlinearFunction{false}((u, p) -> u .* u .- p), [0.1]),
        (NonlinearFunction{true}((du, u, p) -> du .= u .* u .- p), [0.1]))

    # Each case is instantiated once per floating point type, Float32 first.
    probs_nls = NonlinearProblem[]
    for T in (Float32, Float64)
        for (fn, u0) in nls_cases
            push!(probs_nls, NonlinearProblem(fn, T.(u0), T(2)))
        end
    end

    # `nothing` is passed to `solve` as the algorithm as well.
    nls_algs = (NewtonRaphson(), TrustRegion(), LevenbergMarquardt(), PseudoTransient(),
        GeneralBroyden(), GeneralKlement(), DFSane(), nothing)

    probs_nlls = NonlinearLeastSquaresProblem[]

    # Float64 least-squares cases: residuals with fewer (1) and more (4) entries than
    # the two-element `u0`, in both out-of-place and in-place form.
    nlls_cases_f64 = ((NonlinearFunction{false}((u, p) -> (u .^ 2 .- p)[1:1]),
            [0.1, 0.0]),
        (NonlinearFunction{false}((u, p) -> vcat(u .* u .- p, u .* u .- p)),
            [0.1, 0.1]),
        (NonlinearFunction{true}((du, u, p) -> du[1] = u[1] * u[1] - p,
                resid_prototype = zeros(1)), [0.1, 0.0]),
        (NonlinearFunction{true}((du, u, p) -> du .= vcat(u .* u .- p, u .* u .- p),
                resid_prototype = zeros(4)), [0.1, 0.1]))
    for (fn, u0) in nlls_cases_f64
        push!(probs_nlls, NonlinearLeastSquaresProblem(fn, u0, 2.0))
    end

    # The same four cases again with Float32 initial guesses, prototypes and parameter.
    nlls_cases_f32 = ((NonlinearFunction{false}((u, p) -> (u .^ 2 .- p)[1:1]),
            Float32[0.1, 0.0]),
        (NonlinearFunction{false}((u, p) -> vcat(u .* u .- p, u .* u .- p)),
            Float32[0.1, 0.1]),
        (NonlinearFunction{true}((du, u, p) -> du[1] = u[1] * u[1] - p,
                resid_prototype = zeros(Float32, 1)), Float32[0.1, 0.0]),
        (NonlinearFunction{true}((du, u, p) -> du .= vcat(u .* u .- p, u .* u .- p),
                resid_prototype = zeros(Float32, 4)), Float32[0.1, 0.1]))
    for (fn, u0) in nlls_cases_f32
        push!(probs_nlls, NonlinearLeastSquaresProblem(fn, u0, 2.0f0))
    end

    nlls_algs = (LevenbergMarquardt(), GaussNewton(),
        LevenbergMarquardt(; linsolve = LUFactorization()),
        GaussNewton(; linsolve = LUFactorization()))

    @compile_workload begin
        # A loose tolerance is used for every solve in the workload.
        for prob in probs_nls
            for alg in nls_algs
                solve(prob, alg; abstol = 1e-2)
            end
        end
        for prob in probs_nlls
            for alg in nlls_algs
                solve(prob, alg; abstol = 1e-2)
            end
        end
    end
end
230230

231231
export RadiusUpdateSchemes
232232

src/gaussnewton.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,8 @@ function perform_step!(cache::GaussNewtonCache{iip}) where {iip}
116116

117117
# Use normal form to solve the Linear Problem
118118
if cache.JᵀJ !== nothing
119-
__update_JᵀJ!(cache, Val(:JᵀJ))
120-
__update_Jᵀf!(cache, Val(:JᵀJ))
119+
__update_JᵀJ!(cache)
120+
__update_Jᵀf!(cache)
121121
A, b = __maybe_symmetric(cache.JᵀJ), _vec(cache.Jᵀf)
122122
else
123123
A, b = cache.J, _vec(cache.fu)

src/jacobian.jl

Lines changed: 34 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u::Number,
138138
kwargs...) where {needsJᵀJ, F}
139139
# NOTE: Scalar `u` assumes scalar output from `f`
140140
uf = SciMLBase.JacobianWrapper{false}(f, p)
141-
return uf, FakeLinearSolveJLCache(u, u), u, nothing, nothing, u, u, u
141+
return uf, FakeLinearSolveJLCache(u, u), u, zero(u), nothing, u, u, u
142142
end
143143

144144
# Linear Solve Cache
@@ -208,27 +208,49 @@ function __concrete_vjp_autodiff(vjp_autodiff, jvp_autodiff, uf)
208208
end
209209
end
210210

211+
# jvp fallback scalar
212+
# Generic path: forward every positional and keyword argument to `JacVec` unchanged.
__jacvec(args...; kwargs...) = JacVec(args...; kwargs...)

# Scalar path: for a `Number` state `u`, return a `JVPScalar` built from `uf`, `u` and
# the `autodiff` backend. Any additional keyword arguments are accepted and ignored.
#
# Throws `ArgumentError` if `autodiff` is not an `AutoForwardDiff`.
function __jacvec(uf, u::Number; autodiff, kwargs...)
    # Validate with an explicit throw rather than `@assert`: assertions are meant for
    # internal invariants and may be disabled, which would let an unsupported backend
    # through silently.
    autodiff isa AutoForwardDiff ||
        throw(ArgumentError("Only ForwardDiff is currently supported."))
    return JVPScalar(uf, u, autodiff)
end
217+
218+
@concrete mutable struct JVPScalar
219+
uf
220+
u
221+
autodiff
222+
end
223+
224+
function Base.:*(jvp::JVPScalar, v)
225+
T = typeof(ForwardDiff.Tag(typeof(jvp.uf), typeof(jvp.u)))
226+
out = jvp.uf(ForwardDiff.Dual{T}(jvp.u, v))
227+
return ForwardDiff.extract_derivative(T, out)
228+
end
229+
211230
# Generic Handling of Krylov Methods for Normal Form Linear Solves
212-
function __update_JᵀJ!(cache::AbstractNonlinearSolveCache)
231+
function __update_JᵀJ!(cache::AbstractNonlinearSolveCache, J = nothing)
213232
if !(cache.JᵀJ isa KrylovJᵀJ)
214-
@bb cache.JᵀJ = transpose(cache.J) × cache.J
233+
J_ = ifelse(J === nothing, cache.J, J)
234+
@bb cache.JᵀJ = transpose(J_) × J_
215235
end
216236
end
217237

218-
function __update_Jᵀf!(cache::AbstractNonlinearSolveCache)
238+
function __update_Jᵀf!(cache::AbstractNonlinearSolveCache, J = nothing)
219239
if cache.JᵀJ isa KrylovJᵀJ
220240
@bb cache.Jᵀf = cache.JᵀJ.Jᵀ × cache.fu
221241
else
222-
@bb cache.Jᵀf = transpose(cache.J) × vec(cache.fu)
242+
J_ = ifelse(J === nothing, cache.J, J)
243+
@bb cache.Jᵀf = transpose(J_) × vec(cache.fu)
223244
end
224245
end
225246

226247
# Left-Right Multiplication
227-
__lr_mul(::Val, H, g) = dot(g, H, g)
228-
## TODO: Use a cache here to avoid allocations
229-
__lr_mul(::Val{false}, H::KrylovJᵀJ, g) = dot(g, H.JᵀJ, g)
230-
function __lr_mul(::Val{true}, H::KrylovJᵀJ, g)
231-
c = similar(g)
232-
mul!(c, H.JᵀJ, g)
233-
return dot(g, c)
248+
# Convenience entry point: use the `JᵀJ` and `Jᵀf` stored on the cache.
function __lr_mul(cache::AbstractNonlinearSolveCache)
    return __lr_mul(cache, cache.JᵀJ, cache.Jᵀf)
end

# Krylov variant: the operator is taken from the `JᵀJ` field of the `KrylovJᵀJ` wrapper.
# The product is written to `cache.lr_mul_cache` and then dotted with `Jᵀf`.
function __lr_mul(cache::AbstractNonlinearSolveCache, JᵀJ::KrylovJᵀJ, Jᵀf)
    @bb cache.lr_mul_cache = JᵀJ.JᵀJ × vec(Jᵀf)
    g = _vec(Jᵀf)
    return dot(g, _vec(cache.lr_mul_cache))
end

# Generic variant: same computation with `JᵀJ` used directly.
# NOTE(review): unlike the Krylov method, `Jᵀf` is passed to `×` without `vec` here —
# confirm this difference is intentional.
function __lr_mul(cache::AbstractNonlinearSolveCache, JᵀJ, Jᵀf)
    @bb cache.lr_mul_cache = JᵀJ × Jᵀf
    g = _vec(Jᵀf)
    return dot(g, _vec(cache.lr_mul_cache))
end

0 commit comments

Comments
 (0)