Skip to content

Commit 628ff57

Browse files
committed
impose function syntax
1 parent b7e8667 commit 628ff57

File tree

4 files changed

+54
-113
lines changed

4 files changed

+54
-113
lines changed

src/LinearSolve.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,13 @@ using Reexport
2626
abstract type SciMLLinearSolveAlgorithm <: SciMLBase.AbstractLinearAlgorithm end
2727
abstract type AbstractFactorization <: SciMLLinearSolveAlgorithm end
2828
abstract type AbstractKrylovSubspaceMethod <: SciMLLinearSolveAlgorithm end
29-
abstract type AbstractFunctionCall <: SciMLLinearSolveAlgorithm end
29+
abstract type AbstractSolveFunction <: SciMLLinearSolveAlgorithm end
3030

3131
# Traits
3232

3333
needs_concrete_A(alg::AbstractFactorization) = true
3434
needs_concrete_A(alg::AbstractKrylovSubspaceMethod) = false
35-
needs_concrete_A(alg::AbstractFunctionCall) = false
35+
needs_concrete_A(alg::AbstractSolveFunction) = false
3636

3737
# Code
3838

@@ -41,7 +41,7 @@ include("factorization.jl")
4141
include("simplelu.jl")
4242
include("iterative_wrappers.jl")
4343
include("preconditioners.jl")
44-
include("function_call.jl")
44+
include("solve_function.jl")
4545
include("default.jl")
4646
include("init.jl")
4747

@@ -52,7 +52,7 @@ export LUFactorization, SVDFactorization, QRFactorization, GenericFactorization,
5252
GenericLUFactorization, SimpleLUFactorization, RFLUFactorization,
5353
UMFPACKFactorization, KLUFactorization
5454

55-
export FunctionCall, ApplyLDivBang, ApplyLDivBang2Args, ApplyLDivBang3Args
55+
export LinearSolveFunction
5656

5757
export KrylovJL, KrylovJL_CG, KrylovJL_GMRES, KrylovJL_BICGSTAB, KrylovJL_MINRES,
5858
IterativeSolversJL, IterativeSolversJL_CG, IterativeSolversJL_GMRES,

src/function_call.jl

Lines changed: 0 additions & 49 deletions
This file was deleted.

src/solve_function.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
#
2+
function DEFAULT_LINEAR_SOLVE(A,b,u,p,newA,Pl,Pr,solverdata;kwargs...)
3+
solve(LinearProblem(A, b; u0=u); p=p, kwargs...).u
4+
end
5+
6+
Base.@kwdef struct LinearSolveFunction{F} <: AbstractSolveFunction
7+
solve_func::F = DEFAULT_LINEAR_SOLVE
8+
end
9+
10+
function SciMLBase.solve(cache::LinearCache, alg::LinearSolveFunction,
11+
args...; kwargs...)
12+
@unpack A,b,u,p,isfresh,Pl,Pr,cacheval = cache
13+
@unpack solve_func = alg
14+
15+
u = solve_func(A,b,u,p,isfresh,Pl,Pr,cacheval;kwargs...)
16+
cache = set_u(cache, u)
17+
18+
return SciMLBase.build_linear_solution(alg,cache.u,nothing,cache)
19+
end

test/basictests.jl

Lines changed: 31 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -37,65 +37,6 @@ end
3737

3838
@testset "LinearSolve" begin
3939

40-
@testset "Apply Function" begin
41-
42-
@testset "Diagonal Type" begin
43-
A1 = rand(n) |> Diagonal; b1 = rand(n); x1 = zero(b1)
44-
A2 = rand(n) |> Diagonal; b2 = rand(n); x2 = zero(b1)
45-
46-
prob1 = LinearProblem(A1, b1; u0=x1)
47-
prob2 = LinearProblem(A1, b1; u0=x1)
48-
49-
for alg in (
50-
FunctionCall(LinearAlgebra.ldiv!, (:u, :A, :b)),
51-
ApplyLDivBang(),
52-
ApplyLDivBang2Args(),
53-
ApplyLDivBang3Args(),
54-
)
55-
test_interface(alg, prob1, prob2)
56-
end
57-
end
58-
59-
@testset "Custom Type" begin
60-
61-
struct MyDiag
62-
d
63-
end
64-
65-
# overloads
66-
(D::MyDiag)(du, u, p, t) = mul!(du, D, u)
67-
Base.:*(D::MyDiag, u) = Diagonal(D.d) * u
68-
69-
Base.copy(D::MyDiag) = copy(D.d) |> MyDiag
70-
71-
LinearAlgebra.mul!(y, D::MyDiag, x) = mul!(y, Diagonal(D.d), x)
72-
LinearAlgebra.ldiv!(y, D::MyDiag, x) = ldiv!(y, Diagonal(D.d), x)
73-
LinearAlgebra.ldiv!(D::MyDiag, x) = ldiv!(Diagonal(D.d), x)
74-
75-
# custom inverse function
76-
function my_inv!(D::MyDiag, u, b)
77-
@. u = b / D.d
78-
end
79-
80-
A1 = rand(n) |> MyDiag; b1 = rand(n); x1 = zero(b1)
81-
A2 = rand(n) |> MyDiag; b2 = rand(n); x2 = zero(b1)
82-
83-
prob1 = LinearProblem(A1, b1; u0=x1)
84-
prob2 = LinearProblem(A1, b1; u0=x1)
85-
86-
for alg in (
87-
FunctionCall(LinearAlgebra.ldiv!, (:u, :A, :b)),
88-
FunctionCall(my_inv!, (:A, :u, :b)),
89-
ApplyLDivBang(),
90-
ApplyLDivBang2Args(),
91-
ApplyLDivBang3Args(),
92-
)
93-
test_interface(alg, prob1, prob2)
94-
end
95-
end
96-
end
97-
#=
98-
9940
@testset "Default Linear Solver" begin
10041
test_interface(nothing, prob1, prob2)
10142

@@ -345,6 +286,36 @@ end
345286
@test sol13.u sol23.u
346287
@test sol13.u sol33.u
347288
end
348-
=#
289+
290+
@testset "Solve Function" begin
291+
292+
A1 = rand(n) |> Diagonal; b1 = rand(n); x1 = zero(b1)
293+
A2 = rand(n) |> Diagonal; b2 = rand(n); x2 = zero(b1)
294+
295+
function sol_func(A,b,u,p,newA,Pl,Pr,solverdata;verbose=true, kwargs...)
296+
if verbose == true
297+
println("out-of-place solve")
298+
end
299+
u = A \ b
300+
end
301+
302+
function sol_func!(A,b,u,p,newA,Pl,Pr,solverdata;verbose=true, kwargs...)
303+
if verbose == true
304+
println("in-place solve")
305+
end
306+
ldiv!(u,A,b)
307+
end
308+
309+
prob1 = LinearProblem(A1, b1; u0=x1)
310+
prob2 = LinearProblem(A1, b1; u0=x1)
311+
312+
for alg in (
313+
LinearSolveFunction(),
314+
LinearSolveFunction(sol_func),
315+
LinearSolveFunction(sol_func!),
316+
)
317+
test_interface(alg, prob1, prob2)
318+
end
319+
end
349320

350321
end # testset

0 commit comments

Comments
 (0)