Skip to content

Commit 76556a1

Browse files
committed
Merge remote-tracking branch 'upstream/main'
2 parents 86da8b3 + 5e5aed1 commit 76556a1

17 files changed

+139
-45
lines changed

lib/SimpleNonlinearSolve/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SimpleNonlinearSolve"
22
uuid = "727e6d20-b764-4bd8-a329-72de5adea6c7"
33
authors = ["SciML"]
4-
version = "1.10.1"
4+
version = "1.11.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveChainRulesCoreExt.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@ module SimpleNonlinearSolveChainRulesCoreExt
22

33
using ChainRulesCore: ChainRulesCore, NoTangent
44
using DiffEqBase: DiffEqBase
5-
using SciMLBase: ChainRulesOriginator, NonlinearProblem, NonlinearLeastSquaresProblem
6-
using SimpleNonlinearSolve: SimpleNonlinearSolve
5+
using SciMLBase: ChainRulesOriginator, NonlinearLeastSquaresProblem
6+
using SimpleNonlinearSolve: SimpleNonlinearSolve, ImmutableNonlinearProblem
77

88
# The expectation here is that no-one is using this directly inside a GPU kernel. We can
99
# eventually lift this requirement using a custom adjoint
1010
function ChainRulesCore.rrule(::typeof(SimpleNonlinearSolve.__internal_solve_up),
11-
prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem},
11+
prob::Union{ImmutableNonlinearProblem, NonlinearLeastSquaresProblem},
1212
sensealg, u0, u0_changed, p, p_changed, alg, args...; kwargs...)
1313
out, ∇internal = DiffEqBase._solve_adjoint(
1414
prob, sensealg, u0, p, ChainRulesOriginator(), alg, args...; kwargs...)

lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveReverseDiffExt.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@ module SimpleNonlinearSolveReverseDiffExt
33
using ArrayInterface: ArrayInterface
44
using DiffEqBase: DiffEqBase
55
using ReverseDiff: ReverseDiff, TrackedArray, TrackedReal
6-
using SciMLBase: ReverseDiffOriginator, NonlinearProblem, NonlinearLeastSquaresProblem
7-
using SimpleNonlinearSolve: SimpleNonlinearSolve
6+
using SciMLBase: ReverseDiffOriginator, NonlinearLeastSquaresProblem
7+
using SimpleNonlinearSolve: SimpleNonlinearSolve, ImmutableNonlinearProblem
88
import SimpleNonlinearSolve: __internal_solve_up
99

10-
for pType in (NonlinearProblem, NonlinearLeastSquaresProblem)
10+
for pType in (ImmutableNonlinearProblem, NonlinearLeastSquaresProblem)
1111
@eval begin
1212
function __internal_solve_up(prob::$(pType), sensealg, u0::TrackedArray, u0_changed,
1313
p::TrackedArray, p_changed, alg, args...; kwargs...)

lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveTrackerExt.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
module SimpleNonlinearSolveTrackerExt
22

33
using DiffEqBase: DiffEqBase
4-
using SciMLBase: TrackerOriginator, NonlinearProblem, NonlinearLeastSquaresProblem, remake
5-
using SimpleNonlinearSolve: SimpleNonlinearSolve
4+
using SciMLBase: TrackerOriginator, NonlinearLeastSquaresProblem, remake
5+
using SimpleNonlinearSolve: SimpleNonlinearSolve, ImmutableNonlinearProblem
66
using Tracker: Tracker, TrackedArray
77

8-
for pType in (NonlinearProblem, NonlinearLeastSquaresProblem)
8+
for pType in (ImmutableNonlinearProblem, NonlinearLeastSquaresProblem)
99
@eval begin
1010
function SimpleNonlinearSolve.__internal_solve_up(
1111
prob::$(pType), sensealg, u0::TrackedArray,

lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,11 @@ using LinearAlgebra: LinearAlgebra, I, convert, copyto!, diagind, dot, issuccess
1919
norm, transpose
2020
using MaybeInplace: @bb, setindex_trait, CanSetindex, CannotSetindex
2121
using Reexport: @reexport
22-
using SciMLBase: SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem,
23-
NonlinearFunction, NonlinearLeastSquaresProblem, NonlinearProblem,
24-
ReturnCode, init, remake, solve, AbstractNonlinearAlgorithm,
25-
build_solution, isinplace, _unwrap_val
22+
using SciMLBase: @add_kwonly, SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem,
23+
AbstractNonlinearFunction, StandardNonlinearProblem, NonlinearFunction,
24+
NonlinearLeastSquaresProblem, NonlinearProblem, ReturnCode, init, remake,
25+
solve, AbstractNonlinearAlgorithm, build_solution, isinplace, _unwrap_val,
26+
warn_paramtype
2627
using Setfield: @set!
2728
using StaticArraysCore: StaticArray, SVector, SMatrix, SArray, MArray, Size
2829

@@ -35,7 +36,7 @@ abstract type AbstractBracketingAlgorithm <: AbstractSimpleNonlinearSolveAlgorit
3536
abstract type AbstractNewtonAlgorithm <: AbstractSimpleNonlinearSolveAlgorithm end
3637

3738
@inline __is_extension_loaded(::Val) = false
38-
39+
include("immutable_nonlinear_problem.jl")
3940
include("utils.jl")
4041
include("linesearch.jl")
4142

@@ -70,6 +71,19 @@ end
7071
# By Pass the highlevel checks for NonlinearProblem for Simple Algorithms
7172
function SciMLBase.solve(prob::NonlinearProblem, alg::AbstractSimpleNonlinearSolveAlgorithm,
7273
args...; sensealg = nothing, u0 = nothing, p = nothing, kwargs...)
74+
prob = convert(ImmutableNonlinearProblem, prob)
75+
if sensealg === nothing && haskey(prob.kwargs, :sensealg)
76+
sensealg = prob.kwargs[:sensealg]
77+
end
78+
new_u0 = u0 !== nothing ? u0 : prob.u0
79+
new_p = p !== nothing ? p : prob.p
80+
return __internal_solve_up(prob, sensealg, new_u0, u0 === nothing, new_p,
81+
p === nothing, alg, args...; prob.kwargs..., kwargs...)
82+
end
83+
84+
function SciMLBase.solve(
85+
prob::ImmutableNonlinearProblem, alg::AbstractSimpleNonlinearSolveAlgorithm,
86+
args...; sensealg = nothing, u0 = nothing, p = nothing, kwargs...)
7387
if sensealg === nothing && haskey(prob.kwargs, :sensealg)
7488
sensealg = prob.kwargs[:sensealg]
7589
end
@@ -79,8 +93,8 @@ function SciMLBase.solve(prob::NonlinearProblem, alg::AbstractSimpleNonlinearSol
7993
p === nothing, alg, args...; prob.kwargs..., kwargs...)
8094
end
8195

82-
function __internal_solve_up(_prob::NonlinearProblem, sensealg, u0, u0_changed,
83-
p, p_changed, alg, args...; kwargs...)
96+
function __internal_solve_up(_prob::ImmutableNonlinearProblem, sensealg, u0,
97+
u0_changed, p, p_changed, alg, args...; kwargs...)
8498
prob = u0_changed || p_changed ? remake(_prob; u0, p) : _prob
8599
return SciMLBase.__solve(prob, alg, args...; kwargs...)
86100
end

lib/SimpleNonlinearSolve/src/ad.jl

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,26 @@
1-
for pType in (NonlinearProblem, NonlinearLeastSquaresProblem)
2-
@eval function SciMLBase.solve(
3-
prob::$(pType){<:Union{Number, <:AbstractArray}, iip,
4-
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}},
5-
alg::AbstractSimpleNonlinearSolveAlgorithm,
6-
args...;
7-
kwargs...) where {T, V, P, iip}
8-
sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...)
9-
dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p)
10-
return SciMLBase.build_solution(
11-
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original)
12-
end
1+
function SciMLBase.solve(
2+
prob::NonlinearLeastSquaresProblem{<:Union{Number, <:AbstractArray}, iip,
3+
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}},
4+
alg::AbstractSimpleNonlinearSolveAlgorithm,
5+
args...;
6+
kwargs...) where {T, V, P, iip}
7+
sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...)
8+
dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p)
9+
return SciMLBase.build_solution(
10+
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original)
11+
end
12+
13+
function SciMLBase.solve(
14+
prob::NonlinearProblem{<:Union{Number, <:AbstractArray}, iip,
15+
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}},
16+
alg::AbstractSimpleNonlinearSolveAlgorithm,
17+
args...;
18+
kwargs...) where {T, V, P, iip}
19+
prob = convert(ImmutableNonlinearProblem, prob)
20+
sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...)
21+
dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p)
22+
return SciMLBase.build_solution(
23+
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original)
1324
end
1425

1526
for algType in (Bisection, Brent, Alefeld, Falsi, ITP, Ridder)
@@ -31,7 +42,8 @@ for algType in (Bisection, Brent, Alefeld, Falsi, ITP, Ridder)
3142
end
3243

3344
function __nlsolve_ad(
34-
prob::Union{IntervalNonlinearProblem, NonlinearProblem}, alg, args...; kwargs...)
45+
prob::Union{IntervalNonlinearProblem, NonlinearProblem, ImmutableNonlinearProblem},
46+
alg, args...; kwargs...)
3547
p = value(prob.p)
3648
if prob isa IntervalNonlinearProblem
3749
tspan = value.(prob.tspan)
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
struct ImmutableNonlinearProblem{uType, isinplace, P, F, K, PT} <:
2+
AbstractNonlinearProblem{uType, isinplace}
3+
f::F
4+
u0::uType
5+
p::P
6+
problem_type::PT
7+
kwargs::K
8+
@add_kwonly function ImmutableNonlinearProblem{iip}(
9+
f::AbstractNonlinearFunction{iip}, u0, p = NullParameters(),
10+
problem_type = StandardNonlinearProblem(); kwargs...) where {iip}
11+
if haskey(kwargs, :p)
12+
error("`p` specified as a keyword argument `p = $(kwargs[:p])` to `NonlinearProblem`. This is not supported.")
13+
end
14+
warn_paramtype(p)
15+
new{typeof(u0), iip, typeof(p), typeof(f), typeof(kwargs), typeof(problem_type)}(
16+
f, u0, p, problem_type, kwargs)
17+
end
18+
19+
"""
20+
Define a steady state problem using the given function.
21+
`isinplace` optionally sets whether the function is inplace or not.
22+
This is determined automatically, but not inferred.
23+
"""
24+
function ImmutableNonlinearProblem{iip}(
25+
f, u0, p = NullParameters(); kwargs...) where {iip}
26+
ImmutableNonlinearProblem{iip}(NonlinearFunction{iip}(f), u0, p; kwargs...)
27+
end
28+
end
29+
30+
"""
31+
Define a nonlinear problem using an instance of
32+
[`AbstractNonlinearFunction`](@ref AbstractNonlinearFunction).
33+
"""
34+
function ImmutableNonlinearProblem(
35+
f::AbstractNonlinearFunction, u0, p = NullParameters(); kwargs...)
36+
ImmutableNonlinearProblem{isinplace(f)}(f, u0, p; kwargs...)
37+
end
38+
39+
function ImmutableNonlinearProblem(f, u0, p = NullParameters(); kwargs...)
40+
ImmutableNonlinearProblem(NonlinearFunction(f), u0, p; kwargs...)
41+
end
42+
43+
"""
44+
Define a ImmutableNonlinearProblem problem from SteadyStateProblem
45+
"""
46+
function ImmutableNonlinearProblem(prob::AbstractNonlinearProblem)
47+
ImmutableNonlinearProblem{isinplace(prob)}(prob.f, prob.u0, prob.p)
48+
end
49+
50+
function Base.convert(
51+
::Type{ImmutableNonlinearProblem}, prob::T) where {T <: NonlinearProblem}
52+
ImmutableNonlinearProblem{isinplace(prob)}(
53+
prob.f, prob.u0, prob.p, prob.problem_type; prob.kwargs...)
54+
end
55+
56+
function DiffEqBase.get_concrete_problem(
57+
prob::ImmutableNonlinearProblem, isadapt; kwargs...)
58+
u0 = DiffEqBase.get_concrete_u0(prob, isadapt, nothing, kwargs)
59+
u0 = DiffEqBase.promote_u0(u0, prob.p, nothing)
60+
p = DiffEqBase.get_concrete_p(prob, kwargs)
61+
DiffEqBase.remake(prob; u0 = u0, p = p)
62+
end

lib/SimpleNonlinearSolve/src/nlsolve/broyden.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ end
2222

2323
__get_linesearch(::SimpleBroyden{LS}) where {LS} = Val(LS)
2424

25-
function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleBroyden, args...;
25+
function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleBroyden, args...;
2626
abstol = nothing, reltol = nothing, maxiters = 1000,
2727
alias_u0 = false, termination_condition = nothing, kwargs...)
2828
x = __maybe_unaliased(prob.u0, alias_u0)

lib/SimpleNonlinearSolve/src/nlsolve/dfsane.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ function SimpleDFSane(; σ_min::Real = 1e-10, σ_max::Real = 1e10, σ_1::Real =
5454
σ_min, σ_max, σ_1, γ, τ_min, τ_max, nexp, η_strategy)
5555
end
5656

57-
function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleDFSane{M}, args...;
57+
function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleDFSane{M}, args...;
5858
abstol = nothing, reltol = nothing, maxiters = 1000, alias_u0 = false,
5959
termination_condition = nothing, kwargs...) where {M}
6060
x = __maybe_unaliased(prob.u0, alias_u0)

lib/SimpleNonlinearSolve/src/nlsolve/halley.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ A low-overhead implementation of Halley's Method.
2424
autodiff = nothing
2525
end
2626

27-
function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleHalley, args...;
27+
function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleHalley, args...;
2828
abstol = nothing, reltol = nothing, maxiters = 1000,
2929
alias_u0 = false, termination_condition = nothing, kwargs...)
3030
x = __maybe_unaliased(prob.u0, alias_u0)

0 commit comments

Comments
 (0)