Skip to content

Commit 35ce6e2

Browse files
authored
Merge branch 'master' into u/termination_condition
2 parents f20c9bc + 191a237 commit 35ce6e2

File tree

9 files changed

+64
-44
lines changed

9 files changed

+64
-44
lines changed

Project.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,17 @@ StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
2525
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
2626

2727
[weakdeps]
28+
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
2829
FastLevenbergMarquardt = "7a0df574-e128-4d35-8cbd-3d84502bf7ce"
2930
LeastSquaresOptim = "0fc2ff8b-aaa3-5acd-a817-1944a5e08891"
3031

3132
[extensions]
33+
NonlinearSolveBandedMatricesExt = "BandedMatrices"
3234
NonlinearSolveFastLevenbergMarquardtExt = "FastLevenbergMarquardt"
3335
NonlinearSolveLeastSquaresOptimExt = "LeastSquaresOptim"
3436

3537
[compat]
38+
BandedMatrices = "1"
3639
ADTypes = "0.2"
3740
ArrayInterface = "6.0.24, 7"
3841
ConcreteStructs = "0.2"
@@ -58,6 +61,7 @@ Zygote = "0.6"
5861
julia = "1.9"
5962

6063
[extras]
64+
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
6165
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
6266
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
6367
FastLevenbergMarquardt = "7a0df574-e128-4d35-8cbd-3d84502bf7ce"
@@ -78,4 +82,4 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
7882
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
7983

8084
[targets]
81-
test = ["Enzyme", "BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra", "Zygote", "SparseDiffTools", "NonlinearProblemLibrary", "LeastSquaresOptim", "FastLevenbergMarquardt", "NaNMath", "DiffEqBase"]
85+
test = ["Enzyme", "BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra", "Zygote", "SparseDiffTools", "NonlinearProblemLibrary", "LeastSquaresOptim", "FastLevenbergMarquardt", "NaNMath", "BandedMatrices", "DiffEqBase"]
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
module NonlinearSolveBandedMatricesExt
2+
3+
using BandedMatrices, LinearAlgebra, NonlinearSolve, SparseArrays
4+
5+
# This is used if we vcat a Banded Jacobian with a Diagonal Matrix in Levenberg
6+
@inline NonlinearSolve._vcat(B::BandedMatrix, D::Diagonal) = vcat(sparse(B), D)
7+
8+
end

src/NonlinearSolve.jl

Lines changed: 31 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -4,22 +4,29 @@ if isdefined(Base, :Experimental) && isdefined(Base.Experimental, Symbol("@max_m
44
@eval Base.Experimental.@max_methods 1
55
end
66

7-
using DiffEqBase, LinearAlgebra, LinearSolve, SparseArrays, SparseDiffTools
8-
import ArrayInterface: restructure
9-
import ForwardDiff
10-
11-
import ADTypes: AbstractFiniteDifferencesMode
12-
import ArrayInterface: undefmatrix, matrix_colors, parameterless_type, ismutable, issingular
13-
import ConcreteStructs: @concrete
14-
import EnumX: @enumx
15-
import ForwardDiff: Dual
16-
import LinearSolve: ComposePreconditioner, InvPreconditioner, needs_concrete_A
17-
import RecursiveArrayTools: ArrayPartition,
18-
AbstractVectorOfArray, recursivecopy!, recursivefill!, recursive_unitless_bottom_eltype
197
import Reexport: @reexport
20-
import SciMLBase: AbstractNonlinearAlgorithm, NLStats, _unwrap_val, has_jac, isinplace
21-
import StaticArraysCore: StaticArray, SVector, SArray, MArray
22-
import UnPack: @unpack
8+
import PrecompileTools
9+
10+
PrecompileTools.@recompile_invalidations begin
11+
using DiffEqBase, LinearAlgebra, LinearSolve, SparseArrays, SparseDiffTools
12+
import ArrayInterface: restructure
13+
14+
import ADTypes: AbstractFiniteDifferencesMode
15+
import ArrayInterface: undefmatrix,
16+
matrix_colors, parameterless_type, ismutable, issingular
17+
import ConcreteStructs: @concrete
18+
import EnumX: @enumx
19+
import ForwardDiff
20+
import ForwardDiff: Dual
21+
import LinearSolve: ComposePreconditioner, InvPreconditioner, needs_concrete_A
22+
import RecursiveArrayTools: ArrayPartition,
23+
AbstractVectorOfArray, recursivecopy!, recursivefill!
24+
import SciMLBase: AbstractNonlinearAlgorithm, NLStats, _unwrap_val, has_jac, isinplace
25+
import StaticArraysCore: StaticArray, SVector, SArray, MArray
26+
import UnPack: @unpack
27+
28+
using ADTypes, LineSearches, SciMLBase, SimpleNonlinearSolve
29+
end
2330

2431
@reexport using ADTypes, LineSearches, SciMLBase, SimpleNonlinearSolve
2532

@@ -81,25 +88,17 @@ include("jacobian.jl")
8188
include("ad.jl")
8289
include("default.jl")
8390

84-
import PrecompileTools
85-
86-
@static if VERSION v"1.10-"
87-
PrecompileTools.@compile_workload begin
88-
for T in (Float32, Float64)
89-
prob = NonlinearProblem{false}((u, p) -> u .* u .- p, T(0.1), T(2))
90-
91-
precompile_algs = (NewtonRaphson(), TrustRegion(), LevenbergMarquardt(),
92-
PseudoTransient(), GeneralBroyden(), GeneralKlement(), nothing)
91+
PrecompileTools.@compile_workload begin
92+
for T in (Float32, Float64)
93+
probs = (NonlinearProblem{false}((u, p) -> u .* u .- p, T(0.1), T(2)),
94+
NonlinearProblem{false}((u, p) -> u .* u .- p, T[0.1], T[2]),
95+
NonlinearProblem{true}((du, u, p) -> du .= u .* u .- p, T[0.1], T[2]))
9396

94-
for alg in precompile_algs
95-
solve(prob, alg, abstol = T(1e-2))
96-
end
97+
precompile_algs = (NewtonRaphson(), TrustRegion(), LevenbergMarquardt(),
98+
PseudoTransient(), GeneralBroyden(), GeneralKlement(), nothing)
9799

98-
prob = NonlinearProblem{true}((du, u, p) -> du[1] = u[1] * u[1] - p[1], T[0.1],
99-
T[2])
100-
for alg in precompile_algs
101-
solve(prob, alg, abstol = T(1e-2))
102-
end
100+
for prob in probs, alg in precompile_algs
101+
solve(prob, alg, abstol = T(1e-2))
103102
end
104103
end
105104
end

src/levenberg.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip},
225225
rhs_tmp = nothing
226226
else
227227
# Preserve Types
228-
mat_tmp = vcat(J, DᵀD)
228+
mat_tmp = _vcat(J, DᵀD)
229229
fill!(mat_tmp, zero(eltype(u)))
230230
rhs_tmp = vcat(_vec(fu1), _vec(u))
231231
fill!(rhs_tmp, zero(eltype(u)))

src/utils.jl

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -314,14 +314,17 @@ function _try_factorize_and_check_singular!(linsolve, X)
314314
end
315315
_try_factorize_and_check_singular!(::Nothing, x) = _issingular(x), false
316316

317-
_reshape(x, args...) = reshape(x, args...)
318-
_reshape(x::Number, args...) = x
317+
@inline _reshape(x, args...) = reshape(x, args...)
318+
@inline _reshape(x::Number, args...) = x
319319

320320
@generated function _axpy!(α, x, y)
321321
hasmethod(axpy!, Tuple{α, x, y}) && return :(axpy!(α, x, y))
322322
return :(@. y += α * x)
323323
end
324324

325-
_needs_square_A(_, ::Number) = true
326-
_needs_square_A(_, ::StaticArray) = true
327-
_needs_square_A(alg, _) = LinearSolve.needs_square_A(alg.linsolve)
325+
@inline _needs_square_A(_, ::Number) = true
326+
@inline _needs_square_A(_, ::StaticArray) = true
327+
@inline _needs_square_A(alg, _) = LinearSolve.needs_square_A(alg.linsolve)
328+
329+
# Define special concatenation for certain Array combinations
330+
@inline _vcat(x, y) = vcat(x, y)

test/GPU/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
[deps]
22
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
3+
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
34
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
45

56
[compat]
67
CUDA = "5"
8+
LinearSolve = "2"
79
NonlinearSolve = "2"

test/gpu.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using CUDA, NonlinearSolve
1+
using CUDA, NonlinearSolve, LinearSolve
22

33
CUDA.allowscalar(false)
44

test/misc.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# Miscellaneous Tests
2+
using BandedMatrices, LinearAlgebra, NonlinearSolve, SparseArrays, Test
3+
4+
b = BandedMatrix(Ones(5, 5), (1, 1))
5+
d = Diagonal(ones(5, 5))
6+
7+
@test NonlinearSolve._vcat(b, d) == vcat(sparse(b), d)

test/runtests.jl

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,7 @@ end
1717
@time @safetestset "Sparsity Tests" include("sparse.jl")
1818
@time @safetestset "Polyalgs" include("polyalgs.jl")
1919
@time @safetestset "Matrix Resizing" include("matrix_resizing.jl")
20-
if VERSION v"1.10-"
21-
# Takes too long to compile on older versions
22-
@time @safetestset "Nonlinear Least Squares" include("nonlinear_least_squares.jl")
23-
end
20+
@time @safetestset "Nonlinear Least Squares" include("nonlinear_least_squares.jl")
2421
end
2522

2623
if GROUP == "All" || GROUP == "23TestProblems"

0 commit comments

Comments
 (0)