Skip to content

Commit c5ab344

Browse files
committed
Remove DiffEqBase Dependency
1 parent b1373d3 commit c5ab344

File tree

6 files changed

+242
-16
lines changed

6 files changed

+242
-16
lines changed

Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,18 @@ authors = ["Kanav Gupta <[email protected]>"]
44
version = "0.1.0"
55

66
[deps]
7-
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
87
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
98
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
9+
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
1010
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1111
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
1212
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1313
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
1414

1515
[extras]
1616
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
17-
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1817
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
18+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1919

2020
[targets]
21-
test = ["BenchmarkTools", "Test", "ForwardDiff"]
21+
test = ["BenchmarkTools", "Test", "ForwardDiff"]

src/NonlinearSolve.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,22 @@
11
module NonlinearSolve
22

33
using Reexport
4-
@reexport using DiffEqBase
54
using UnPack: @unpack
65
using FiniteDiff, ForwardDiff
76
using Setfield
87
using StaticArrays
8+
using RecursiveArrayTools
99

10+
abstract type AbstractNonlinearProblem{uType,isinplace} end
1011
abstract type AbstractNonlinearSolveAlgorithm end
1112
abstract type AbstractBracketingAlgorithm <: AbstractNonlinearSolveAlgorithm end
1213
abstract type AbstractNewtonAlgorithm{CS,AD} <: AbstractNonlinearSolveAlgorithm end
1314
abstract type AbstractNonlinearSolver end
1415
abstract type AbstractImmutableNonlinearSolver <: AbstractNonlinearSolver end
1516

17+
include("utils.jl")
1618
include("jacobian.jl")
1719
include("types.jl")
18-
include("utils.jl")
1920
include("solve.jl")
2021
include("bisection.jl")
2122
include("falsi.jl")
@@ -28,5 +29,9 @@ module NonlinearSolve
2829
# DiffEq styled algorithms
2930
export Bisection, Falsi, NewtonRaphson
3031

32+
export NonlinearProblem
33+
34+
export solve, init, solve!
35+
3136
export reinit!
3237
end # module

src/scalar.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
function DiffEqBase.solve(prob::NonlinearProblem{<:Number}, ::NewtonRaphson, args...; xatol = nothing, xrtol = nothing, maxiters = 1000, kwargs...)
1+
function solve(prob::NonlinearProblem{<:Number}, ::NewtonRaphson, args...; xatol = nothing, xrtol = nothing, maxiters = 1000, kwargs...)
22
f = Base.Fix2(prob.f, prob.p)
33
x = float(prob.u0)
44
T = typeof(x)
@@ -19,7 +19,7 @@ function DiffEqBase.solve(prob::NonlinearProblem{<:Number}, ::NewtonRaphson, arg
1919
return NewtonSolution(x, MAXITERS_EXCEED)
2020
end
2121

22-
function DiffEqBase.solve(prob::NonlinearProblem, ::Bisection, args...; maxiters = 1000, kwargs...)
22+
function solve(prob::NonlinearProblem, ::Bisection, args...; maxiters = 1000, kwargs...)
2323
f = Base.Fix2(prob.f, prob.p)
2424
left, right = prob.u0
2525
fl, fr = f(left), f(right)

src/solve.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1-
function DiffEqBase.solve(prob::NonlinearProblem,
1+
function solve(prob::NonlinearProblem,
22
alg::AbstractNonlinearSolveAlgorithm, args...;
33
kwargs...)
4-
solver = DiffEqBase.init(prob, alg, args...; kwargs...)
4+
solver = init(prob, alg, args...; kwargs...)
55
sol = solve!(solver)
66
return sol
77
end
88

9-
function DiffEqBase.init(prob::NonlinearProblem{uType, iip}, alg::AbstractBracketingAlgorithm, args...;
9+
function init(prob::NonlinearProblem{uType, iip}, alg::AbstractBracketingAlgorithm, args...;
1010
alias_u0 = false,
1111
maxiters = 1000,
1212
kwargs...
@@ -33,11 +33,11 @@ function DiffEqBase.init(prob::NonlinearProblem{uType, iip}, alg::AbstractBracke
3333
return BracketingImmutableSolver(1, f, alg, left, right, fl, fr, p, false, maxiters, DEFAULT, cache, iip)
3434
end
3535

36-
function DiffEqBase.init(prob::NonlinearProblem{uType, iip}, alg::AbstractNewtonAlgorithm, args...;
36+
function init(prob::NonlinearProblem{uType, iip}, alg::AbstractNewtonAlgorithm, args...;
3737
alias_u0 = false,
3838
maxiters = 1000,
3939
tol = 1e-6,
40-
internalnorm = Base.Fix2(DiffEqBase.ODE_DEFAULT_NORM, nothing),
40+
internalnorm = DEFAULT_NORM,
4141
kwargs...
4242
) where {uType, iip}
4343

@@ -58,7 +58,7 @@ function DiffEqBase.init(prob::NonlinearProblem{uType, iip}, alg::AbstractNewton
5858
return NewtonImmutableSolver(1, f, alg, u, fu, p, false, maxiters, internalnorm, DEFAULT, tol, cache, iip)
5959
end
6060

61-
function DiffEqBase.solve!(solver::AbstractImmutableNonlinearSolver)
61+
function solve!(solver::AbstractImmutableNonlinearSolver)
6262
solver = mic_check(solver)
6363
while !solver.force_stop && solver.iter < solver.maxiters
6464
solver = perform_step(solver, solver.alg, Val(solver.iip))

src/types.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,17 @@
1+
struct NullParameters end
2+
3+
struct NonlinearProblem{uType,isinplace,P,F,K} <: AbstractNonlinearProblem{uType,isinplace}
4+
f::F
5+
u0::uType
6+
p::P
7+
kwargs::K
8+
@add_kwonly function NonlinearProblem{iip}(f,u0,p=NullParameters();kwargs...) where iip
9+
new{typeof(u0),iip,typeof(p),typeof(f),typeof(kwargs)}(f,u0,p,kwargs)
10+
end
11+
end
12+
13+
NonlinearProblem(f,u0,args...;kwargs...) = NonlinearProblem{isinplace(f, 3)}(f,u0,args...;kwargs...)
14+
115
@enum Retcode::Int begin
216
DEFAULT
317
EXACT_SOLUTION_LEFT

src/utils.jl

Lines changed: 210 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,213 @@
1+
"""
2+
@add_kwonly function_definition
3+
4+
Define keyword-only version of the `function_definition`.
5+
6+
@add_kwonly function f(x; y=1)
7+
...
8+
end
9+
10+
expands to:
11+
12+
function f(x; y=1)
13+
...
14+
end
15+
function f(; x = error("No argument x"), y=1)
16+
...
17+
end
18+
"""
19+
macro add_kwonly(ex)
20+
esc(add_kwonly(ex))
21+
end
22+
23+
add_kwonly(ex::Expr) = add_kwonly(Val{ex.head}, ex)
24+
25+
function add_kwonly(::Type{<: Val}, ex)
26+
error("add_only does not work with expression $(ex.head)")
27+
end
28+
29+
function add_kwonly(::Union{Type{Val{:function}},
30+
Type{Val{:(=)}}}, ex::Expr)
31+
body = ex.args[2:end] # function body
32+
default_call = ex.args[1] # e.g., :(f(a, b=2; c=3))
33+
kwonly_call = add_kwonly(default_call)
34+
if kwonly_call === nothing
35+
return ex
36+
end
37+
38+
return quote
39+
begin
40+
$ex
41+
$(Expr(ex.head, kwonly_call, body...))
42+
end
43+
end
44+
end
45+
46+
function add_kwonly(::Type{Val{:where}}, ex::Expr)
47+
default_call = ex.args[1]
48+
rest = ex.args[2:end]
49+
kwonly_call = add_kwonly(default_call)
50+
if kwonly_call === nothing
51+
return nothing
52+
end
53+
return Expr(:where, kwonly_call, rest...)
54+
end
55+
56+
function add_kwonly(::Type{Val{:call}}, default_call::Expr)
57+
# default_call is, e.g., :(f(a, b=2; c=3))
58+
funcname = default_call.args[1] # e.g., :f
59+
required = [] # required positional arguments; e.g., [:a]
60+
optional = [] # optional positional arguments; e.g., [:(b=2)]
61+
default_kwargs = []
62+
for arg in default_call.args[2:end]
63+
if isa(arg, Symbol)
64+
push!(required, arg)
65+
elseif arg.head == :(::)
66+
push!(required, arg)
67+
elseif arg.head == :kw
68+
push!(optional, arg)
69+
elseif arg.head == :parameters
70+
@assert default_kwargs == [] # can I have :parameters twice?
71+
default_kwargs = arg.args
72+
else
73+
error("Not expecting to see: $arg")
74+
end
75+
end
76+
if isempty(required) && isempty(optional)
77+
# If the function is already keyword-only, do nothing:
78+
return nothing
79+
end
80+
if isempty(required)
81+
# It's not clear what should be done. Let's not support it at
82+
# the moment:
83+
error("At least one positional mandatory argument is required.")
84+
end
85+
86+
kwonly_kwargs = Expr(:parameters, [
87+
Expr(:kw, pa, :(error($("No argument $pa"))))
88+
for pa in required
89+
]..., optional..., default_kwargs...)
90+
kwonly_call = Expr(:call, funcname, kwonly_kwargs)
91+
# e.g., :(f(; a=error(...), b=error(...), c=1, d=2))
92+
93+
return kwonly_call
94+
end
95+
96+
function num_types_in_tuple(sig)
97+
length(sig.parameters)
98+
end
99+
100+
function num_types_in_tuple(sig::UnionAll)
101+
length(Base.unwrap_unionall(sig).parameters)
102+
end
103+
104+
function numargs(f)
105+
typ = Tuple{Any, Val{:analytic}, Vararg}
106+
typ2 = Tuple{Any, Type{Val{:analytic}}, Vararg} # This one is required for overloaded types
107+
typ3 = Tuple{Any, Val{:jac}, Vararg}
108+
typ4 = Tuple{Any, Type{Val{:jac}}, Vararg} # This one is required for overloaded types
109+
typ5 = Tuple{Any, Val{:tgrad}, Vararg}
110+
typ6 = Tuple{Any, Type{Val{:tgrad}}, Vararg} # This one is required for overloaded types
111+
numparam = maximum([(m.sig<:typ || m.sig<:typ2 || m.sig<:typ3 || m.sig<:typ4 || m.sig<:typ5 || m.sig<:typ6) ? 0 : num_types_in_tuple(m.sig) for m in methods(f)])
112+
return (numparam-1) #-1 in v0.5 since it adds f as the first parameter
113+
end
114+
115+
function isinplace(f,inplace_param_number)
116+
numargs(f)>=inplace_param_number
117+
end
118+
119+
### Default Linsolve
120+
121+
# Try to be as smart as possible
122+
# lu! if Matrix
123+
# lu if sparse
124+
# gmres if operator
125+
126+
mutable struct DefaultLinSolve
127+
A
128+
iterable
129+
end
130+
DefaultLinSolve() = DefaultLinSolve(nothing, nothing)
131+
132+
function (p::DefaultLinSolve)(x,A,b,update_matrix=false;tol=nothing, kwargs...)
133+
if p.iterable isa Vector && eltype(p.iterable) <: LinearAlgebra.BlasInt # `iterable` here is the pivoting vector
134+
F = LU{eltype(A)}(A, p.iterable, zero(LinearAlgebra.BlasInt))
135+
ldiv!(x, F, b)
136+
return nothing
137+
end
138+
if update_matrix
139+
if typeof(A) <: Matrix
140+
blasvendor = BLAS.vendor()
141+
# if the user doesn't use OpenBLAS, we assume that is a better BLAS
142+
# implementation like MKL
143+
#
144+
# RecursiveFactorization seems to be consistantly winning below 100
145+
# https://discourse.julialang.org/t/ann-recursivefactorization-jl/39213
146+
if ArrayInterface.can_setindex(x) && (size(A,1) <= 100 || ((blasvendor === :openblas || blasvendor === :openblas64) && size(A,1) <= 500))
147+
p.A = RecursiveFactorization.lu!(A)
148+
else
149+
p.A = lu!(A)
150+
end
151+
elseif typeof(A) <: Tridiagonal
152+
p.A = lu!(A)
153+
elseif typeof(A) <: Union{SymTridiagonal}
154+
p.A = ldlt!(A)
155+
elseif typeof(A) <: Union{Symmetric,Hermitian}
156+
p.A = bunchkaufman!(A)
157+
elseif typeof(A) <: SparseMatrixCSC
158+
p.A = lu(A)
159+
elseif ArrayInterface.isstructured(A)
160+
p.A = factorize(A)
161+
elseif !(typeof(A) <: AbstractDiffEqOperator)
162+
# Most likely QR is the one that is overloaded
163+
# Works on things like CuArrays
164+
p.A = qr(A)
165+
end
166+
end
167+
168+
if typeof(A) <: Union{Matrix,SymTridiagonal,Tridiagonal,Symmetric,Hermitian} # No 2-arg form for SparseArrays!
169+
x .= b
170+
ldiv!(p.A,x)
171+
# Missing a little bit of efficiency in a rare case
172+
#elseif typeof(A) <: DiffEqArrayOperator
173+
# ldiv!(x,p.A,b)
174+
elseif ArrayInterface.isstructured(A) || A isa SparseMatrixCSC
175+
ldiv!(x,p.A,b)
176+
elseif typeof(A) <: AbstractDiffEqOperator
177+
# No good starting guess, so guess zero
178+
if p.iterable === nothing
179+
p.iterable = IterativeSolvers.gmres_iterable!(x,A,b;initially_zero=true,restart=5,maxiter=5,tol=1e-16,kwargs...)
180+
p.iterable.reltol = tol
181+
end
182+
x .= false
183+
iter = p.iterable
184+
purge_history!(iter, x, b)
185+
186+
for residual in iter
187+
end
188+
else
189+
ldiv!(x,p.A,b)
190+
end
191+
return nothing
192+
end
193+
194+
function (p::DefaultLinSolve)(::Type{Val{:init}},f,u0_prototype)
195+
DefaultLinSolve()
196+
end
197+
198+
const DEFAULT_LINSOLVE = DefaultLinSolve()
199+
200+
@inline UNITLESS_ABS2(x) = real(abs2(x))
201+
@inline DEFAULT_NORM(u::Union{AbstractFloat,Complex}) = @fastmath abs(u)
202+
@inline DEFAULT_NORM(u::Array{T}) where T<:Union{AbstractFloat,Complex} =
203+
sqrt(real(sum(abs2,u)) / length(u))
204+
@inline DEFAULT_NORM(u::StaticArray{T}) where T<:Union{AbstractFloat,Complex} =
205+
sqrt(real(sum(abs2,u)) / length(u))
206+
@inline DEFAULT_NORM(u::RecursiveArrayTools.AbstractVectorOfArray) =
207+
sum(sqrt(real(sum(UNITLESS_ABS2,_u)) / length(_u)) for _u in u.u)
208+
@inline DEFAULT_NORM(u::AbstractArray) = sqrt(real(sum(UNITLESS_ABS2,u)) / length(u))
209+
@inline DEFAULT_NORM(u) = norm(u)
210+
1211
"""
2212
prevfloat_tdir(x, x0, x1)
3213
@@ -24,6 +234,3 @@ function value_derivative(f::F, x::R) where {F,R}
24234
out = f(ForwardDiff.Dual{T}(x, one(x)))
25235
ForwardDiff.value(out), ForwardDiff.extract_derivative(T, out)
26236
end
27-
28-
DiffEqBase.has_Wfact(f::Function) = false
29-
DiffEqBase.has_Wfact_t(f::Function) = false

0 commit comments

Comments
 (0)