Skip to content

Commit 1d0c424

Browse files
Handle non-vector inputs
1 parent 2fd9480 commit 1d0c424

File tree

8 files changed

+37
-11
lines changed

8 files changed

+37
-11
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ PrecompileTools = "1"
5050
RecursiveArrayTools = "2"
5151
Reexport = "0.2, 1"
5252
SciMLBase = "2.4"
53-
SimpleNonlinearSolve = "0.1.22"
53+
SimpleNonlinearSolve = "0.1.23"
5454
SparseDiffTools = "2.6"
5555
StaticArraysCore = "1.4"
5656
UnPack = "1.0"

src/NonlinearSolve.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ if isdefined(Base, :Experimental) && isdefined(Base.Experimental, Symbol("@max_m
55
end
66

77
using DiffEqBase, LinearAlgebra, LinearSolve, SparseArrays, SparseDiffTools
8+
import ArrayInterface: restructure
89
import ForwardDiff
910

1011
import ADTypes: AbstractFiniteDifferencesMode

src/jacobian.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f, u, p, ::Val{ii
9292
if needsJᵀJ
9393
JᵀJ = __init_JᵀJ(J)
9494
# FIXME: This needs to be handled better for JacVec Operator
95-
Jᵀfu = J' * fu
95+
Jᵀfu = J' * _vec(fu)
9696
end
9797

9898
linprob = LinearProblem(needsJᵀJ ? __maybe_symmetric(JᵀJ) : J,

src/levenberg.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip},
172172
else
173173
d = similar(u)
174174
d .= min_damping_D
175-
DᵀD = Diagonal(d)
175+
DᵀD = Diagonal(_vec(d))
176176
end
177177

178178
loss = internalnorm(fu1)
@@ -289,7 +289,7 @@ function perform_step!(cache::LevenbergMarquardtCache{false})
289289
cache.v = -cache.mat_tmp \ (J' * fu1)
290290
else
291291
linres = dolinsolve(alg.precs, linsolve; A = -__maybe_symmetric(cache.mat_tmp),
292-
b = _vec(J' * fu1), linu = _vec(cache.v), p, reltol = cache.abstol)
292+
b = _vec(J' * _vec(fu1)), linu = _vec(cache.v), p, reltol = cache.abstol)
293293
cache.linsolve = linres.cache
294294
end
295295

@@ -301,7 +301,7 @@ function perform_step!(cache::LevenbergMarquardtCache{false})
301301
else
302302
linres = dolinsolve(alg.precs, linsolve;
303303
b = _mutable(_vec(J' *
304-
((2 / h) .* ((f(u .+ h .* v, p) .- fu1) ./ h .- J * v)))),
304+
_vec(((2 / h) .* (_vec(f(u .+ h .* _restructure(u,v), p)) .- _vec(fu1) ./ h .- J * _vec(v)))))),
305305
linu = _vec(cache.a), p, reltol = cache.abstol)
306306
cache.linsolve = linres.cache
307307
end

src/trustRegion.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,7 @@ function perform_step!(cache::TrustRegionCache{true})
347347
if cache.make_new_J
348348
jacobian!!(J, cache)
349349
mul!(cache.H, J', J)
350-
mul!(cache.g, J', fu)
350+
mul!(_vec(cache.g), J', _vec(fu))
351351
cache.stats.njacs += 1
352352

353353
# do not use A = cache.H, b = _vec(cache.g) since it is equivalent
@@ -378,9 +378,9 @@ function perform_step!(cache::TrustRegionCache{false})
378378
if make_new_J
379379
J = jacobian!!(cache.J, cache)
380380
cache.H = J' * J
381-
cache.g = J' * fu
381+
cache.g = _restructure(fu, J' * _vec(fu))
382382
cache.stats.njacs += 1
383-
cache.u_gauss_newton = -1 .* (cache.H \ cache.g)
383+
cache.u_gauss_newton = -1 .* _restructure(cache.g, cache.H \ _vec(cache.g))
384384
end
385385

386386
# Compute the Newton step.
@@ -419,7 +419,7 @@ function trust_region_step!(cache::TrustRegionCache)
419419
cache.loss_new = get_loss(fu_new)
420420

421421
# Compute the ratio of the actual reduction to the predicted reduction.
422-
cache.r = -(loss - cache.loss_new) / (dot(du, g) + dot(du, H, du) / 2)
422+
cache.r = -(loss - cache.loss_new) / (dot(_vec(du), _vec(g)) + dot(_vec(du), H, _vec(du)) / 2)
423423
@unpack r = cache
424424

425425
if radius_update_scheme === RadiusUpdateSchemes.Simple
@@ -597,7 +597,7 @@ function dogleg!(cache::TrustRegionCache{true})
597597

598598
# Take intersection of steepest descent direction and trust region if Cauchy point lies outside of trust region
599599
l_grad = norm(cache.g) # length of the gradient
600-
d_cauchy = l_grad^3 / dot(cache.g, cache.H, cache.g) # distance of the cauchy point from the current iterate
600+
d_cauchy = l_grad^3 / dot(_vec(cache.g), cache.H, _vec(cache.g)) # distance of the cauchy point from the current iterate
601601
if d_cauchy >= trust_r
602602
@. cache.du = -(trust_r / l_grad) * cache.g # step to the end of the trust region
603603
return
@@ -627,7 +627,7 @@ function dogleg!(cache::TrustRegionCache{false})
627627

628628
## Take intersection of steepest descent direction and trust region if Cauchy point lies outside of trust region
629629
l_grad = norm(cache.g)
630-
d_cauchy = l_grad^3 / dot(cache.g, cache.H, cache.g) # distance of the cauchy point from the current iterate
630+
d_cauchy = l_grad^3 / dot(_vec(cache.g), cache.H, _vec(cache.g)) # distance of the cauchy point from the current iterate
631631
if d_cauchy > trust_r # cauchy point lies outside of trust region
632632
cache.du = -(trust_r / l_grad) * cache.g # step to the end of the trust region
633633
return

src/utils.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,9 @@ end
7474
@inline _vec(v::Number) = v
7575
@inline _vec(v::AbstractVector) = v
7676

77+
@inline _restructure(y,x) = restructure(y,x)
78+
@inline _restructure(y::Number,x::Number) = x
79+
7780
DEFAULT_PRECS(W, du, u, p, t, newW, Plprev, Prprev, cachedata) = nothing, nothing
7881

7982
function dolinsolve(precs::P, linsolve; A = nothing, linu = nothing, b = nothing,

test/matrix_resizing.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
using NonlinearSolve, Test
2+
3+
ff(u, p) = u .* u .- p
4+
u0 = rand(2,2)
5+
p = 2.0
6+
vecprob = NonlinearProblem(ff, vec(u0), p)
7+
prob = NonlinearProblem(ff, u0, p)
8+
9+
for alg in (NewtonRaphson(), TrustRegion(), LevenbergMarquardt(), PseudoTransient(), RobustMultiNewton(), FastShortcutNonlinearPolyalg())
10+
@test vec(solve(prob, alg).u) == solve(vecprob, alg).u
11+
end
12+
13+
fiip(du, u, p) = (du .= u .* u .- p)
14+
u0 = rand(2,2)
15+
p = 2.0
16+
vecprob = NonlinearProblem(fiip, vec(u0), p)
17+
prob = NonlinearProblem(fiip, u0, p)
18+
19+
for alg in (NewtonRaphson(), TrustRegion(), LevenbergMarquardt(), PseudoTransient(), RobustMultiNewton(), FastShortcutNonlinearPolyalg())
20+
@test vec(solve(prob, alg).u) == solve(vecprob, alg).u
21+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ end
1616
@time @safetestset "Basic Tests + Some AD" include("basictests.jl")
1717
@time @safetestset "Sparsity Tests" include("sparse.jl")
1818
@time @safetestset "Polyalgs" include("polyalgs.jl")
19+
@time @safetestset "Matrix Resizing" include("matrix_resizing.jl")
1920
@time @safetestset "Nonlinear Least Squares" include("nonlinear_least_squares.jl")
2021
end
2122

0 commit comments

Comments
 (0)