Skip to content

Commit 0f7bf2b

Browse files
committed
test cases
1 parent bf5ea7f commit 0f7bf2b

File tree

3 files changed

+93
-11
lines changed

3 files changed

+93
-11
lines changed

src/LinearSolve.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,11 @@ isopenblas() = IS_OPENBLAS[]
4848

4949
export LUFactorization, SVDFactorization, QRFactorization, GenericFactorization,
5050
GenericLUFactorization, SimpleLUFactorization, RFLUFactorization,
51-
UMFPACKFactorization, KLUFactorization,
52-
FunctionCall
51+
UMFPACKFactorization, KLUFactorization
52+
53+
export FunctionCall, LdivBang2Args, LDivBang3Args,
54+
ApplyLDivBang2Args, ApplyLDivBang3Args
55+
5356
export KrylovJL, KrylovJL_CG, KrylovJL_GMRES, KrylovJL_BICGSTAB, KrylovJL_MINRES,
5457
IterativeSolversJL, IterativeSolversJL_CG, IterativeSolversJL_GMRES,
5558
IterativeSolversJL_BICGSTAB, IterativeSolversJL_MINRES,

src/function_call.jl

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,46 @@
1+
2+
""" user passes in inverse function, and arg symbols """
13
struct FunctionCall{F,A} <: SciMLLinearSolveAlgorithm
24
func!::F
3-
argsyms::A
5+
argnames::A
46

5-
function FunctionCall(func!::Function, argsyms::Tuple)
6-
new{typeof(func!), typeof(argsyms)}(func!, argsyms)
7+
function FunctionCall(func!::Function, argnames::Tuple)
8+
# @assert hasfield(::LinearCache, argnames)
9+
# @assert isdefined
10+
new{typeof(func!), typeof(argnames)}(func!, argnames)
711
end
812
end
913

10-
LdivBang2Args() = FunctionCall(ldiv!, (:A, :u))
11-
LdivBang3Args() = FunctionCall(ldiv!, (:u, :A, :b))
12-
1314
function (f::FunctionCall)(cache::LinearCache)
14-
@unpack func!, argsyms = f
15-
args = [getproperty(cache,argsym) for argsym in argsyms]
15+
@unpack func!, argnames = f
16+
args = [getproperty(cache,argname) for argname in argnames]
1617
func!(args...)
1718
end
1819

1920
function SciMLBase.solve(cache::LinearCache, alg::FunctionCall,
2021
args...; kwargs...)
2122
@unpack u, b = cache
22-
copy!(u, b)
2323
alg(cache)
24+
return SciMLBase.build_linear_solution(alg,cache.u,nothing,cache)
25+
end
26+
27+
##
2428

29+
""" apply ldiv!(A, u) """
30+
struct ApplyLDivBang2Args <: SciMLLinearSolveAlgorithm end
31+
function SciMLBase.solve(cache::LinearCache, alg::ApplyLDivBang2Args,
32+
args...; kwargs...)
33+
@unpack A, b, u = cache
34+
copy!(u, b)
35+
LinearAlgebra.ldiv!(A, u)
36+
return SciMLBase.build_linear_solution(alg,cache.u,nothing,cache)
37+
end
38+
39+
""" apply ldiv!(u, A, b) """
40+
struct ApplyLDivBang3Args <: SciMLLinearSolveAlgorithm end
41+
function SciMLBase.solve(cache::LinearCache, alg::ApplyLDivBang3Args,
42+
args...; kwargs...)
43+
@unpack A, b, u = cache
44+
LinearAlgebra.ldiv!(u, A, b)
2545
return SciMLBase.build_linear_solution(alg,cache.u,nothing,cache)
2646
end

test/basictests.jl

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,64 @@ end
3737

3838
@testset "LinearSolve" begin
3939

40+
@testset "Apply Function" begin
41+
42+
@testset "Diagonal Type" begin
43+
A1 = rand(n) |> Diagonal; b1 = rand(n); x1 = zero(b1)
44+
A2 = rand(n) |> Diagonal; b2 = rand(n); x2 = zero(b1)
45+
46+
prob1 = LinearProblem(A1, b1; u0=x1)
47+
prob2 = LinearProblem(A1, b1; u0=x1)
48+
49+
for alg in (
50+
FunctionCall(LinearAlgebra.ldiv!, (:u, :A, :b)),
51+
ApplyLDivBang2Args(),
52+
ApplyLDivBang3Args(),
53+
)
54+
test_interface(alg, prob1, prob2)
55+
end
56+
end
57+
58+
@testset "Custom Type" begin
59+
60+
struct MyDiag
61+
d
62+
end
63+
64+
# overloads
65+
(D::MyDiag)(du, u, p, t) = mul!(du, D, u)
66+
Base.:*(D::MyDiag, u) = Diagonal(D.d) * u
67+
68+
Base.copy(D::MyDiag) = copy(D.d) |> MyDiag
69+
70+
LinearAlgebra.mul!(y, D::MyDiag, x) = mul!(y, Diagonal(D.d), x)
71+
LinearAlgebra.ldiv!(y, D::MyDiag, x) = ldiv!(y, Diagonal(D.d), x)
72+
LinearAlgebra.ldiv!(D::MyDiag, x) = ldiv!(Diagonal(D.d), x)
73+
74+
# custom inverse function
75+
function my_inv!(D::MyDiag, u, b)
76+
@. u = b / D.d
77+
end
78+
79+
A1 = rand(n) |> MyDiag; b1 = rand(n); x1 = zero(b1)
80+
A2 = rand(n) |> MyDiag; b2 = rand(n); x2 = zero(b1)
81+
82+
prob1 = LinearProblem(A1, b1; u0=x1)
83+
prob2 = LinearProblem(A1, b1; u0=x1)
84+
85+
for alg in (
86+
FunctionCall(LinearAlgebra.ldiv!, (:u, :A, :b)),
87+
FunctionCall(my_inv!, (:A, :u, :b)),
88+
ApplyLDivBang2Args(),
89+
ApplyLDivBang3Args(),
90+
)
91+
test_interface(alg, prob1, prob2)
92+
end
93+
end
94+
95+
end
96+
#=
97+
4098
@testset "Default Linear Solver" begin
4199
test_interface(nothing, prob1, prob2)
42100
@@ -286,5 +344,6 @@ end
286344
@test sol13.u ≈ sol23.u
287345
@test sol13.u ≈ sol33.u
288346
end
347+
=#
289348

290349
end # testset

0 commit comments

Comments
 (0)