Skip to content

Commit 71162f9

Browse files
Merge pull request #25 from JuliaComputing/scimlbase
Extend SciMLBase
2 parents 3398cc4 + 33fa361 commit 71162f9

File tree

6 files changed

+21
-38
lines changed

6 files changed

+21
-38
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
88
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
99
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
1010
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
11+
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
1112
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
1213
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1314
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"

src/NonlinearSolve.jl

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ using Setfield
88
using StaticArrays
99
using RecursiveArrayTools
1010

11+
@reexport using SciMLBase
12+
1113
abstract type AbstractNonlinearProblem{uType,isinplace} end
1214
abstract type AbstractNonlinearSolveAlgorithm end
1315
abstract type AbstractBracketingAlgorithm <: AbstractNonlinearSolveAlgorithm end
@@ -27,10 +29,4 @@ include("scalar.jl")
2729
# DiffEq styled algorithms
2830
export Bisection, Falsi, NewtonRaphson
2931

30-
export NonlinearProblem
31-
32-
export solve, init, solve!
33-
34-
export reinit!
35-
3632
end # module

src/raphson.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
struct NewtonRaphson{CS, AD, DT, L} <: AbstractNewtonAlgorithm{CS,AD}
1+
struct NewtonRaphson{CS, AD, DT, L} <: AbstractNewtonAlgorithm{CS,AD}
22
diff_type::DT
33
linsolve::L
44
end

src/scalar.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
function solve(prob::NonlinearProblem{<:Number}, alg::NewtonRaphson, args...; xatol = nothing, xrtol = nothing, maxiters = 1000, kwargs...)
1+
function SciMLBase.solve(prob::NonlinearProblem{<:Number}, alg::NewtonRaphson, args...; xatol = nothing, xrtol = nothing, maxiters = 1000, kwargs...)
22
f = Base.Fix2(prob.f, prob.p)
33
x = float(prob.u0)
44
T = typeof(x)
@@ -48,28 +48,28 @@ function scalar_nlsolve_ad(prob, alg, args...; kwargs...)
4848
return sol, partials
4949
end
5050

51-
function solve(prob::NonlinearProblem{<:Number, iip, <:Dual{T,V,P}}, alg::NewtonRaphson, args...; kwargs...) where {iip, T, V, P}
51+
function SciMLBase.solve(prob::NonlinearProblem{<:Number, iip, <:Dual{T,V,P}}, alg::NewtonRaphson, args...; kwargs...) where {iip, T, V, P}
5252
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
5353
return NewtonSolution(Dual{T,V,P}(sol.u, partials), sol.retcode)
5454
end
55-
function solve(prob::NonlinearProblem{<:Number, iip, <:AbstractArray{<:Dual{T,V,P}}}, alg::NewtonRaphson, args...; kwargs...) where {iip, T, V, P}
55+
function SciMLBase.solve(prob::NonlinearProblem{<:Number, iip, <:AbstractArray{<:Dual{T,V,P}}}, alg::NewtonRaphson, args...; kwargs...) where {iip, T, V, P}
5656
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
5757
return NewtonSolution(Dual{T,V,P}(sol.u, partials), sol.retcode)
5858
end
5959

6060
# avoid ambiguities
6161
for Alg in [Bisection]
62-
@eval function solve(prob::NonlinearProblem{uType, iip, <:Dual{T,V,P}}, alg::$Alg, args...; kwargs...) where {uType, iip, T, V, P}
62+
@eval function SciMLBase.solve(prob::NonlinearProblem{uType, iip, <:Dual{T,V,P}}, alg::$Alg, args...; kwargs...) where {uType, iip, T, V, P}
6363
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
6464
return BracketingSolution(Dual{T,V,P}(sol.left, partials), Dual{T,V,P}(sol.right, partials), sol.retcode)
6565
end
66-
@eval function solve(prob::NonlinearProblem{uType, iip, <:AbstractArray{<:Dual{T,V,P}}}, alg::$Alg, args...; kwargs...) where {uType, iip, T, V, P}
66+
@eval function SciMLBase.solve(prob::NonlinearProblem{uType, iip, <:AbstractArray{<:Dual{T,V,P}}}, alg::$Alg, args...; kwargs...) where {uType, iip, T, V, P}
6767
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
6868
return BracketingSolution(Dual{T,V,P}(sol.left, partials), Dual{T,V,P}(sol.right, partials), sol.retcode)
6969
end
7070
end
7171

72-
function solve(prob::NonlinearProblem, ::Bisection, args...; maxiters = 1000, kwargs...)
72+
function SciMLBase.solve(prob::NonlinearProblem, ::Bisection, args...; maxiters = 1000, kwargs...)
7373
f = Base.Fix2(prob.f, prob.p)
7474
left, right = prob.u0
7575
fl, fr = f(left), f(right)
@@ -116,7 +116,7 @@ function solve(prob::NonlinearProblem, ::Bisection, args...; maxiters = 1000, kw
116116
return BracketingSolution(left, right, MAXITERS_EXCEED)
117117
end
118118

119-
function solve(prob::NonlinearProblem, ::Falsi, args...; maxiters = 1000, kwargs...)
119+
function SciMLBase.solve(prob::NonlinearProblem, ::Falsi, args...; maxiters = 1000, kwargs...)
120120
f = Base.Fix2(prob.f, prob.p)
121121
left, right = prob.u0
122122
fl, fr = f(left), f(right)

src/solve.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1-
function solve(prob::NonlinearProblem,
2-
alg::AbstractNonlinearSolveAlgorithm, args...;
3-
kwargs...)
1+
function SciMLBase.solve(prob::NonlinearProblem,
2+
alg::AbstractNonlinearSolveAlgorithm, args...;
3+
kwargs...)
44
solver = init(prob, alg, args...; kwargs...)
55
sol = solve!(solver)
66
return sol
77
end
88

9-
function init(prob::NonlinearProblem{uType, iip}, alg::AbstractBracketingAlgorithm, args...;
9+
function SciMLBase.init(prob::NonlinearProblem{uType, iip}, alg::AbstractBracketingAlgorithm, args...;
1010
alias_u0 = false,
1111
maxiters = 1000,
1212
kwargs...
@@ -33,7 +33,7 @@ function init(prob::NonlinearProblem{uType, iip}, alg::AbstractBracketingAlgorit
3333
return BracketingImmutableSolver(1, f, alg, left, right, fl, fr, p, false, maxiters, DEFAULT, cache, iip)
3434
end
3535

36-
function init(prob::NonlinearProblem{uType, iip}, alg::AbstractNewtonAlgorithm, args...;
36+
function SciMLBase.init(prob::NonlinearProblem{uType, iip}, alg::AbstractNewtonAlgorithm, args...;
3737
alias_u0 = false,
3838
maxiters = 1000,
3939
tol = 1e-6,
@@ -58,7 +58,7 @@ function init(prob::NonlinearProblem{uType, iip}, alg::AbstractNewtonAlgorithm,
5858
return NewtonImmutableSolver(1, f, alg, u, fu, p, false, maxiters, internalnorm, DEFAULT, tol, cache, iip)
5959
end
6060

61-
function solve!(solver::AbstractImmutableNonlinearSolver)
61+
function SciMLBase.solve!(solver::AbstractImmutableNonlinearSolver)
6262
solver = mic_check(solver)
6363
while !solver.force_stop && solver.iter < solver.maxiters
6464
solver = perform_step(solver, solver.alg, Val(solver.iip))
@@ -115,14 +115,14 @@ end
115115
116116
Reinitialize solver to the original starting conditions
117117
"""
118-
function reinit!(solver::NewtonImmutableSolver, prob::NonlinearProblem{uType, true}) where {uType}
118+
function SciMLBase.reinit!(solver::NewtonImmutableSolver, prob::NonlinearProblem{uType, true}) where {uType}
119119
@. solver.u = prob.u0
120120
@set! solver.iter = 1
121121
@set! solver.force_stop = false
122122
return solver
123123
end
124124

125-
function reinit!(solver::NewtonImmutableSolver, prob::NonlinearProblem{uType, false}) where {uType}
125+
function SciMLBase.reinit!(solver::NewtonImmutableSolver, prob::NonlinearProblem{uType, false}) where {uType}
126126
@set! solver.u = prob.u0
127127
@set! solver.iter = 1
128128
@set! solver.force_stop = false

src/types.jl

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,3 @@
1-
struct NullParameters end
2-
3-
struct NonlinearProblem{uType,isinplace,P,F,K} <: AbstractNonlinearProblem{uType,isinplace}
4-
f::F
5-
u0::uType
6-
p::P
7-
kwargs::K
8-
@add_kwonly function NonlinearProblem{iip}(f,u0,p=NullParameters();kwargs...) where iip
9-
new{typeof(u0),iip,typeof(p),typeof(f),typeof(kwargs)}(f,u0,p,kwargs)
10-
end
11-
end
12-
13-
NonlinearProblem(f,u0,args...;kwargs...) = NonlinearProblem{isinplace(f, 3)}(f,u0,args...;kwargs...)
14-
151
@enum Retcode::Int begin
162
DEFAULT
173
EXACT_SOLUTION_LEFT
@@ -37,7 +23,7 @@ struct BracketingImmutableSolver{fType, algType, uType, resType, pType, cacheTyp
3723
end
3824

3925
# function BracketingImmutableSolver(iip, iter, f, alg, left, right, fl, fr, p, force_stop, maxiters, retcode, cache)
40-
# BracketingImmutableSolver{iip, typeof(f), typeof(alg),
26+
# BracketingImmutableSolver{iip, typeof(f), typeof(alg),
4127
# typeof(left), typeof(fl), typeof(p), typeof(cache)}(iter, f, alg, left, right, fl, fr, p, force_stop, maxiters, retcode, cache)
4228
# end
4329

@@ -58,7 +44,7 @@ struct NewtonImmutableSolver{fType, algType, uType, resType, pType, INType, tolT
5844
end
5945

6046
# function NewtonImmutableSolver{iip}(iter, f, alg, u, fu, p, force_stop, maxiters, internalnorm, retcode, tol, cache) where iip
61-
# NewtonImmutableSolver{iip, typeof(f), typeof(alg), typeof(u),
47+
# NewtonImmutableSolver{iip, typeof(f), typeof(alg), typeof(u),
6248
# typeof(fu), typeof(p), typeof(internalnorm), typeof(tol), typeof(cache)}(iter, f, alg, u, fu, p, force_stop, maxiters, internalnorm, retcode, tol, cache)
6349
# end
6450

0 commit comments

Comments
 (0)