Skip to content

Commit a6c17cb

Browse files
Merge pull request #254 from SciML/nonvector
Handle non-vector inputs
2 parents 2fd9480 + e8778ed commit a6c17cb

File tree

8 files changed

+47
-21
lines changed

8 files changed

+47
-21
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ PrecompileTools = "1"
5050
RecursiveArrayTools = "2"
5151
Reexport = "0.2, 1"
5252
SciMLBase = "2.4"
53-
SimpleNonlinearSolve = "0.1.22"
53+
SimpleNonlinearSolve = "0.1.23"
5454
SparseDiffTools = "2.6"
5555
StaticArraysCore = "1.4"
5656
UnPack = "1.0"

src/NonlinearSolve.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ if isdefined(Base, :Experimental) && isdefined(Base.Experimental, Symbol("@max_m
55
end
66

77
using DiffEqBase, LinearAlgebra, LinearSolve, SparseArrays, SparseDiffTools
8+
import ArrayInterface: restructure
89
import ForwardDiff
910

1011
import ADTypes: AbstractFiniteDifferencesMode

src/jacobian.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f, u, p, ::Val{ii
9292
if needsJᵀJ
9393
JᵀJ = __init_JᵀJ(J)
9494
# FIXME: This needs to be handled better for JacVec Operator
95-
Jᵀfu = J' * fu
95+
Jᵀfu = J' * _vec(fu)
9696
end
9797

9898
linprob = LinearProblem(needsJᵀJ ? __maybe_symmetric(JᵀJ) : J,

src/levenberg.jl

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip},
172172
else
173173
d = similar(u)
174174
d .= min_damping_D
175-
DᵀD = Diagonal(d)
175+
DᵀD = Diagonal(_vec(d))
176176
end
177177

178178
loss = internalnorm(fu1)
@@ -209,21 +209,21 @@ function perform_step!(cache::LevenbergMarquardtCache{true})
209209

210210
# Usual Levenberg-Marquardt step ("velocity").
211211
# The following lines do: cache.v = -cache.mat_tmp \ cache.u_tmp
212-
mul!(cache.u_tmp, J', fu1)
212+
mul!(_vec(cache.u_tmp), J', _vec(fu1))
213213
@. cache.mat_tmp = JᵀJ + λ * DᵀD
214214
linres = dolinsolve(alg.precs, linsolve; A = __maybe_symmetric(cache.mat_tmp),
215215
b = _vec(cache.u_tmp), linu = _vec(cache.du), p = p, reltol = cache.abstol)
216216
cache.linsolve = linres.cache
217-
@. cache.v = -cache.du
217+
_vec(cache.v) .= -1 .* _vec(cache.du)
218218

219219
# Geodesic acceleration (step_size = v + a / 2).
220220
@unpack v, α_geodesic, h = cache
221-
f(cache.fu_tmp, u .+ h .* v, p)
221+
f(cache.fu_tmp, _restructure(u, _vec(u) .+ h .* _vec(v)), p)
222222

223223
# The following lines do: cache.a = -J \ cache.fu_tmp
224-
mul!(cache.Jv, J, v)
224+
mul!(_vec(cache.Jv), J, _vec(v))
225225
@. cache.fu_tmp = (2 / h) * ((cache.fu_tmp - fu1) / h - cache.Jv)
226-
mul!(cache.u_tmp, J', cache.fu_tmp)
226+
mul!(_vec(cache.u_tmp), J', _vec(cache.fu_tmp))
227227
# NOTE: Don't pass `A` in again, since we want to reuse the previous solve
228228
linres = dolinsolve(alg.precs, linsolve; b = _vec(cache.u_tmp),
229229
linu = _vec(cache.du), p = p, reltol = cache.abstol)
@@ -235,7 +235,7 @@ function perform_step!(cache::LevenbergMarquardtCache{true})
235235
# Require acceptable steps to satisfy the following condition.
236236
norm_v = norm(v)
237237
if 2 * norm(cache.a) α_geodesic * norm_v
238-
@. cache.δ = v + cache.a / 2
238+
_vec(cache.δ) .= _vec(v) .+ _vec(cache.a) ./ 2
239239
@unpack δ, loss_old, norm_v_old, v_old, b_uphill = cache
240240
f(cache.fu_tmp, u .+ δ, p)
241241
cache.stats.nf += 1
@@ -251,7 +251,7 @@ function perform_step!(cache::LevenbergMarquardtCache{true})
251251
return nothing
252252
end
253253
cache.fu1 .= cache.fu_tmp
254-
cache.v_old .= v
254+
_vec(cache.v_old) .= _vec(v)
255255
cache.norm_v_old = norm_v
256256
cache.loss_old = loss
257257
cache.λ_factor = 1 / cache.damping_decrease_factor
@@ -289,7 +289,7 @@ function perform_step!(cache::LevenbergMarquardtCache{false})
289289
cache.v = -cache.mat_tmp \ (J' * fu1)
290290
else
291291
linres = dolinsolve(alg.precs, linsolve; A = -__maybe_symmetric(cache.mat_tmp),
292-
b = _vec(J' * fu1), linu = _vec(cache.v), p, reltol = cache.abstol)
292+
b = _vec(J' * _vec(fu1)), linu = _vec(cache.v), p, reltol = cache.abstol)
293293
cache.linsolve = linres.cache
294294
end
295295

@@ -300,8 +300,8 @@ function perform_step!(cache::LevenbergMarquardtCache{false})
300300
_vec(J' * ((2 / h) .* ((f(u .+ h .* v, p) .- fu1) ./ h .- J * v)))
301301
else
302302
linres = dolinsolve(alg.precs, linsolve;
303-
b = _mutable(_vec(J' *
304-
((2 / h) .* ((f(u .+ h .* v, p) .- fu1) ./ h .- J * v)))),
303+
b = _mutable(_vec(J' * #((2 / h) .* ((f(u .+ h .* v, p) .- fu1) ./ h .- J * v)))),
304+
_vec(((2 / h) .* ((_vec(f(u .+ h .* _restructure(u,v), p)) .- _vec(fu1)) ./ h .- J * _vec(v)))))),
305305
linu = _vec(cache.a), p, reltol = cache.abstol)
306306
cache.linsolve = linres.cache
307307
end
@@ -311,7 +311,7 @@ function perform_step!(cache::LevenbergMarquardtCache{false})
311311
# Require acceptable steps to satisfy the following condition.
312312
norm_v = norm(v)
313313
if 2 * norm(cache.a) α_geodesic * norm_v
314-
cache.δ = v .+ cache.a ./ 2
314+
cache.δ = _restructure(cache.δ,_vec(v) .+ _vec(cache.a) ./ 2)
315315
@unpack δ, loss_old, norm_v_old, v_old, b_uphill = cache
316316
fu_new = f(u .+ δ, p)
317317
cache.stats.nf += 1
@@ -327,7 +327,7 @@ function perform_step!(cache::LevenbergMarquardtCache{false})
327327
return nothing
328328
end
329329
cache.fu1 = fu_new
330-
cache.v_old = v
330+
cache.v_old = _restructure(cache.v_old,v)
331331
cache.norm_v_old = norm_v
332332
cache.loss_old = loss
333333
cache.λ_factor = 1 / cache.damping_decrease_factor

src/trustRegion.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,7 @@ function perform_step!(cache::TrustRegionCache{true})
347347
if cache.make_new_J
348348
jacobian!!(J, cache)
349349
mul!(cache.H, J', J)
350-
mul!(cache.g, J', fu)
350+
mul!(_vec(cache.g), J', _vec(fu))
351351
cache.stats.njacs += 1
352352

353353
# do not use A = cache.H, b = _vec(cache.g) since it is equivalent
@@ -378,9 +378,9 @@ function perform_step!(cache::TrustRegionCache{false})
378378
if make_new_J
379379
J = jacobian!!(cache.J, cache)
380380
cache.H = J' * J
381-
cache.g = J' * fu
381+
cache.g = _restructure(fu, J' * _vec(fu))
382382
cache.stats.njacs += 1
383-
cache.u_gauss_newton = -1 .* (cache.H \ cache.g)
383+
cache.u_gauss_newton = -1 .* _restructure(cache.g, cache.H \ _vec(cache.g))
384384
end
385385

386386
# Compute the Newton step.
@@ -419,7 +419,7 @@ function trust_region_step!(cache::TrustRegionCache)
419419
cache.loss_new = get_loss(fu_new)
420420

421421
# Compute the ratio of the actual reduction to the predicted reduction.
422-
cache.r = -(loss - cache.loss_new) / (dot(du, g) + dot(du, H, du) / 2)
422+
cache.r = -(loss - cache.loss_new) / (dot(_vec(du), _vec(g)) + dot(_vec(du), H, _vec(du)) / 2)
423423
@unpack r = cache
424424

425425
if radius_update_scheme === RadiusUpdateSchemes.Simple
@@ -597,7 +597,7 @@ function dogleg!(cache::TrustRegionCache{true})
597597

598598
# Take intersection of steepest descent direction and trust region if Cauchy point lies outside of trust region
599599
l_grad = norm(cache.g) # length of the gradient
600-
d_cauchy = l_grad^3 / dot(cache.g, cache.H, cache.g) # distance of the cauchy point from the current iterate
600+
d_cauchy = l_grad^3 / dot(_vec(cache.g), cache.H, _vec(cache.g)) # distance of the cauchy point from the current iterate
601601
if d_cauchy >= trust_r
602602
@. cache.du = -(trust_r / l_grad) * cache.g # step to the end of the trust region
603603
return
@@ -627,7 +627,7 @@ function dogleg!(cache::TrustRegionCache{false})
627627

628628
## Take intersection of steepest descent direction and trust region if Cauchy point lies outside of trust region
629629
l_grad = norm(cache.g)
630-
d_cauchy = l_grad^3 / dot(cache.g, cache.H, cache.g) # distance of the cauchy point from the current iterate
630+
d_cauchy = l_grad^3 / dot(_vec(cache.g), cache.H, _vec(cache.g)) # distance of the cauchy point from the current iterate
631631
if d_cauchy > trust_r # cauchy point lies outside of trust region
632632
cache.du = -(trust_r / l_grad) * cache.g # step to the end of the trust region
633633
return

src/utils.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,9 @@ end
7474
@inline _vec(v::Number) = v
7575
@inline _vec(v::AbstractVector) = v
7676

77+
@inline _restructure(y,x) = restructure(y,x)
78+
@inline _restructure(y::Number,x::Number) = x
79+
7780
DEFAULT_PRECS(W, du, u, p, t, newW, Plprev, Prprev, cachedata) = nothing, nothing
7881

7982
function dolinsolve(precs::P, linsolve; A = nothing, linu = nothing, b = nothing,

test/matrix_resizing.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
using NonlinearSolve, Test
2+
3+
ff(u, p) = u .* u .- p
4+
u0 = rand(2,2)
5+
p = 2.0
6+
vecprob = NonlinearProblem(ff, vec(u0), p)
7+
prob = NonlinearProblem(ff, u0, p)
8+
9+
for alg in (NewtonRaphson(), TrustRegion(), LevenbergMarquardt(), PseudoTransient(), RobustMultiNewton(), FastShortcutNonlinearPolyalg())
10+
@test vec(solve(prob, alg).u) == solve(vecprob, alg).u
11+
end
12+
13+
fiip(du, u, p) = (du .= u .* u .- p)
14+
u0 = rand(2,2)
15+
p = 2.0
16+
vecprob = NonlinearProblem(fiip, vec(u0), p)
17+
prob = NonlinearProblem(fiip, u0, p)
18+
19+
for alg in (NewtonRaphson(), TrustRegion(), LevenbergMarquardt(), PseudoTransient(), RobustMultiNewton(), FastShortcutNonlinearPolyalg())
20+
@test vec(solve(prob, alg).u) == solve(vecprob, alg).u
21+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ end
1616
@time @safetestset "Basic Tests + Some AD" include("basictests.jl")
1717
@time @safetestset "Sparsity Tests" include("sparse.jl")
1818
@time @safetestset "Polyalgs" include("polyalgs.jl")
19+
@time @safetestset "Matrix Resizing" include("matrix_resizing.jl")
1920
@time @safetestset "Nonlinear Least Squares" include("nonlinear_least_squares.jl")
2021
end
2122

0 commit comments

Comments
 (0)