Skip to content

Commit 518e53d

Browse files
committed
refactor: minor cleanup
1 parent ba054b7 commit 518e53d

File tree

11 files changed

+120
-77
lines changed

11 files changed

+120
-77
lines changed

common/nlls_problem_workloads.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,23 @@
1+
using SciMLBase: NonlinearLeastSquaresProblem, NonlinearFunction
12

3+
nonlinear_functions = (
4+
(NonlinearFunction{false}((u, p) -> (u .^ 2 .- p)[1:1]), [0.1, 0.0]),
5+
(NonlinearFunction{false}((u, p) -> vcat(u .* u .- p, u .* u .- p)), [0.1, 0.1]),
6+
(
7+
NonlinearFunction{true}(
8+
(du, u, p) -> du[1] = u[1] * u[1] - p, resid_prototype = zeros(1)
9+
),
10+
[0.1, 0.0]
11+
),
12+
(
13+
NonlinearFunction{true}(
14+
(du, u, p) -> du .= vcat(u .* u .- p, u .* u .- p), resid_prototype = zeros(4)
15+
),
16+
[0.1, 0.1]
17+
)
18+
)
19+
20+
nlls_problems = NonlinearLeastSquaresProblem[]
21+
for (fn, u0) in nonlinear_functions
22+
push!(nlls_problems, NonlinearLeastSquaresProblem(fn, u0, 2.0))
23+
end

lib/BracketingNonlinearSolve/Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
name = "BracketingNonlinearSolve"
22
uuid = "70df07ce-3d50-431d-a3e7-ca6ddb60ac1e"
33
authors = ["Avik Pal <[email protected]> and contributors"]
4-
version = "1.0.0"
4+
version = "1.1.0"
55

66
[deps]
77
CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
88
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
99
NonlinearSolveBase = "be0214bd-f91f-a760-ac4e-3421ce2b2da0"
1010
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
11+
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1112
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
1213

1314
[weakdeps]
@@ -25,6 +26,7 @@ ForwardDiff = "0.10.36"
2526
InteractiveUtils = "<0.0.1, 1"
2627
NonlinearSolveBase = "1"
2728
PrecompileTools = "1.2"
29+
Reexport = "1.2"
2830
SciMLBase = "2.50"
2931
Test = "1.10"
3032
TestItemRunner = "1"

lib/BracketingNonlinearSolve/src/BracketingNonlinearSolve.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module BracketingNonlinearSolve
22

33
using ConcreteStructs: @concrete
4+
using Reexport: @reexport
45

56
using CommonSolve: CommonSolve, solve
67
using NonlinearSolveBase: NonlinearSolveBase
@@ -30,7 +31,8 @@ end
3031
@setup_workload begin
3132
for T in (Float32, Float64)
3233
prob_brack = IntervalNonlinearProblem{false}(
33-
(u, p) -> u^2 - p, T.((0.0, 2.0)), T(2))
34+
(u, p) -> u^2 - p, T.((0.0, 2.0)), T(2)
35+
)
3436
algs = (Alefeld(), Bisection(), Brent(), Falsi(), ITP(), Ridder())
3537

3638
@compile_workload begin
@@ -41,8 +43,7 @@ end
4143
end
4244
end
4345

44-
export IntervalNonlinearProblem
45-
export solve
46+
@reexport using SciMLBase, NonlinearSolveBase
4647

4748
export Alefeld, Bisection, Brent, Falsi, ITP, Ridder
4849

lib/NonlinearSolveFirstOrder/Project.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ NonlinearSolveBase = "be0214bd-f91f-a760-ac4e-3421ce2b2da0"
1919
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
2020
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
2121
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
22-
SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
2322
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
2423
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
2524

@@ -46,7 +45,6 @@ PrecompileTools = "1.2"
4645
ReTestItems = "1.24"
4746
Reexport = "1"
4847
SciMLBase = "2.54"
49-
SciMLOperators = "0.3.11"
5048
Setfield = "1.1.1"
5149
StableRNGs = "1"
5250
StaticArraysCore = "1.4.3"

lib/NonlinearSolveFirstOrder/src/NonlinearSolveFirstOrder.jl

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,24 +10,20 @@ using ConcreteStructs: @concrete
1010
using DiffEqBase: DiffEqBase # Needed for `init` / `solve` dispatches
1111
using FiniteDiff: FiniteDiff # Default Finite Difference Method
1212
using ForwardDiff: ForwardDiff # Default Forward Mode AD
13-
using LinearAlgebra: LinearAlgebra, Diagonal, dot, inv, diag
13+
using LinearAlgebra: LinearAlgebra, Diagonal, dot
1414
using LinearSolve: LinearSolve # Trigger Linear Solve extension in NonlinearSolveBase
1515
using MaybeInplace: @bb
1616
using NonlinearSolveBase: NonlinearSolveBase, AbstractNonlinearSolveAlgorithm,
17-
AbstractNonlinearSolveCache, AbstractResetCondition,
18-
AbstractResetConditionCache, AbstractApproximateJacobianStructure,
19-
AbstractJacobianCache, AbstractJacobianInitialization,
20-
AbstractApproximateJacobianUpdateRule, AbstractDescentDirection,
21-
AbstractApproximateJacobianUpdateRuleCache,
22-
AbstractDampingFunction, AbstractDampingFunctionCache,
23-
AbstractTrustRegionMethod, AbstractTrustRegionMethodCache,
17+
AbstractNonlinearSolveCache, AbstractDampingFunction,
18+
AbstractDampingFunctionCache, AbstractTrustRegionMethod,
19+
AbstractTrustRegionMethodCache,
2420
Utils, InternalAPI, get_timer_output, @static_timeit,
2521
update_trace!, L2_NORM,
26-
NewtonDescent, DampedNewtonDescent
27-
using SciMLBase: SciMLBase, AbstractNonlinearProblem, NLStats, ReturnCode
28-
using SciMLOperators: AbstractSciMLOperator
22+
NewtonDescent, DampedNewtonDescent, GeodesicAcceleration,
23+
Dogleg
24+
using SciMLBase: SciMLBase, AbstractNonlinearProblem, NLStats, ReturnCode, NonlinearFunction
2925
using Setfield: @set!
30-
using StaticArraysCore: StaticArray, SArray, Size, MArray
26+
using StaticArraysCore: SArray
3127

3228
include("raphson.jl")
3329
include("gauss_newton.jl")
@@ -37,6 +33,29 @@ include("pseudo_transient.jl")
3733

3834
include("solve.jl")
3935

36+
@setup_workload begin
37+
include(joinpath(
38+
@__DIR__, "..", "..", "..", "common", "nonlinear_problem_workloads.jl"
39+
))
40+
include(joinpath(
41+
@__DIR__, "..", "..", "..", "common", "nlls_problem_workloads.jl"
42+
))
43+
44+
# XXX: TrustRegion
45+
nlp_algs = [NewtonRaphson(), LevenbergMarquardt()]
46+
nlls_algs = [GaussNewton(), LevenbergMarquardt()]
47+
48+
@compile_workload begin
49+
for prob in nonlinear_problems, alg in nlp_algs
50+
CommonSolve.solve(prob, alg; abstol = 1e-2, verbose = false)
51+
end
52+
53+
for prob in nlls_problems, alg in nlls_algs
54+
CommonSolve.solve(prob, alg; abstol = 1e-2, verbose = false)
55+
end
56+
end
57+
end
58+
4059
@reexport using SciMLBase, NonlinearSolveBase
4160

4261
export NewtonRaphson, PseudoTransient

lib/NonlinearSolveFirstOrder/src/levenberg_marquardt.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -131,10 +131,10 @@ end
131131
function InternalAPI.solve!(
132132
cache::LevenbergMarquardtDampingCache, J, fu, ::Val{false}; kwargs...
133133
)
134-
if ArrayInterface.can_setindex(cache.J_diag_cache)
135-
sum!(abs2, Utils.safe_vec(cache.J_diag_cache), J')
136-
elseif cache.J_diag_cache isa Number
134+
if cache.J_diag_cache isa Number
137135
cache.J_diag_cache = abs2(J)
136+
elseif ArrayInterface.can_setindex(cache.J_diag_cache)
137+
sum!(abs2, Utils.safe_vec(cache.J_diag_cache), J')
138138
else
139139
cache.J_diag_cache = dropdims(sum(abs2, J'; dims = 1); dims = 1)
140140
end
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,40 @@
1+
"""
2+
TrustRegion(;
3+
concrete_jac = nothing, linsolve = nothing, precs = nothing,
4+
radius_update_scheme = RadiusUpdateSchemes.Simple, max_trust_radius::Real = 0 // 1,
5+
initial_trust_radius::Real = 0 // 1, step_threshold::Real = 1 // 10000,
6+
shrink_threshold::Real = 1 // 4, expand_threshold::Real = 3 // 4,
7+
shrink_factor::Real = 1 // 4, expand_factor::Real = 2 // 1,
8+
max_shrink_times::Int = 32,
9+
vjp_autodiff = nothing, autodiff = nothing, jvp_autodiff = nothing
10+
)
111
12+
An advanced TrustRegion implementation with support for efficient handling of sparse
13+
matrices via colored automatic differentiation and preconditioned linear solvers. Designed
14+
for large-scale and numerically-difficult nonlinear systems.
15+
16+
### Keyword Arguments
17+
18+
- `radius_update_scheme`: the scheme used to update the trust region radius. Defaults to
19+
`RadiusUpdateSchemes.Simple`. See [`RadiusUpdateSchemes`](@ref) for more details. For a
20+
review on trust region radius update schemes, see [yuan2015recent](@citet).
21+
22+
For the remaining arguments, see [`NonlinearSolve.GenericTrustRegionScheme`](@ref)
23+
documentation.
24+
"""
25+
function TrustRegion(;
26+
concrete_jac = nothing, linsolve = nothing, precs = nothing,
27+
radius_update_scheme = RadiusUpdateSchemes.Simple, max_trust_radius::Real = 0 // 1,
28+
initial_trust_radius::Real = 0 // 1, step_threshold::Real = 1 // 10000,
29+
shrink_threshold::Real = 1 // 4, expand_threshold::Real = 3 // 4,
30+
shrink_factor::Real = 1 // 4, expand_factor::Real = 2 // 1,
31+
max_shrink_times::Int = 32,
32+
autodiff = nothing, vjp_autodiff = nothing, jvp_autodiff = nothing
33+
)
34+
descent = Dogleg(; linsolve, precs)
35+
trustregion = GenericTrustRegionScheme(;
36+
method = radius_update_scheme, step_threshold, shrink_threshold, expand_threshold,
37+
shrink_factor, expand_factor, initial_trust_radius, max_trust_radius)
38+
return GeneralizedFirstOrderAlgorithm{concrete_jac, :TrustRegion}(;
39+
trustregion, descent, autodiff, vjp_autodiff, jvp_autodiff, max_shrink_times)
40+
end

lib/NonlinearSolveQuasiNewton/src/NonlinearSolveQuasiNewton.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,8 @@ include("solve.jl")
4040
algs = [Broyden(), Klement()]
4141

4242
@compile_workload begin
43-
@sync begin
44-
for prob in nonlinear_problems, alg in algs
45-
Threads.@spawn CommonSolve.solve(prob, alg; abstol = 1e-2, verbose = false)
46-
end
43+
for prob in nonlinear_problems, alg in algs
44+
CommonSolve.solve(prob, alg; abstol = 1e-2, verbose = false)
4745
end
4846
end
4947
end

lib/NonlinearSolveSpectralMethods/src/NonlinearSolveSpectralMethods.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,8 @@ include("solve.jl")
2626
algs = [DFSane()]
2727

2828
@compile_workload begin
29-
@sync begin
30-
for prob in nonlinear_problems, alg in algs
31-
Threads.@spawn CommonSolve.solve(prob, alg; abstol = 1e-2, verbose = false)
32-
end
29+
for prob in nonlinear_problems, alg in algs
30+
CommonSolve.solve(prob, alg; abstol = 1e-2, verbose = false)
3331
end
3432
end
3533
end

lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -130,10 +130,8 @@ function solve_adjoint_internal end
130130
#!format: on
131131

132132
@compile_workload begin
133-
@sync for alg in algs
134-
for prob in (prob_scalar, prob_iip, prob_oop)
135-
Threads.@spawn CommonSolve.solve(prob, alg; abstol = 1e-2)
136-
end
133+
for prob in (prob_scalar, prob_iip, prob_oop)
134+
CommonSolve.solve(prob, alg; abstol = 1e-2)
137135
end
138136
end
139137
end

0 commit comments

Comments
 (0)