Skip to content

Commit 5aefad0

Browse files
committed
works. now need figure out default triage
1 parent caa1bee commit 5aefad0

File tree

5 files changed

+35
-16
lines changed

5 files changed

+35
-16
lines changed

src/LinearSolve.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ include("factorization.jl")
3939
include("simplelu.jl")
4040
include("iterative_wrappers.jl")
4141
include("preconditioners.jl")
42+
include("function_call.jl")
4243
include("default.jl")
4344
include("init.jl")
4445

@@ -47,7 +48,8 @@ isopenblas() = IS_OPENBLAS[]
4748

4849
export LUFactorization, SVDFactorization, QRFactorization, GenericFactorization,
4950
GenericLUFactorization, SimpleLUFactorization, RFLUFactorization,
50-
UMFPACKFactorization, KLUFactorization
51+
UMFPACKFactorization, KLUFactorization,
52+
FunctionCall
5153
export KrylovJL, KrylovJL_CG, KrylovJL_GMRES, KrylovJL_BICGSTAB, KrylovJL_MINRES,
5254
IterativeSolversJL, IterativeSolversJL_CG, IterativeSolversJL_GMRES,
5355
IterativeSolversJL_BICGSTAB, IterativeSolversJL_MINRES,

src/common.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ function set_cacheval(cache::LinearCache, alg_cache)
6565
return cache
6666
end
6767

68-
init_cacheval(alg::SciMLLinearSolveAlgorithm, A, b, u) = nothing
68+
init_cacheval(alg::SciMLLinearSolveAlgorithm, args...) = nothing
6969

7070
SciMLBase.init(prob::LinearProblem, args...; kwargs...) = SciMLBase.init(prob,nothing,args...;kwargs...)
7171

src/default.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,10 @@ function SciMLBase.solve(cache::LinearCache, alg::Nothing,
5959
end
6060

6161
if applicable(ldiv!, A, u)
62-
alg = FunctionCall(ldiv!, (A, u))
62+
alg = FunctionCall(ldiv!, (:A, :u))
6363
SciMLBase.solve(cache, alg, args...; kwargs...)
6464
elseif applicable(ldiv!, u, A, b)
65-
alg = FunctionCall(ldiv!, (u, A, b))
65+
alg = FunctionCall(ldiv!, (:u, :A, :b))
6666
SciMLBase.solve(cache, alg, args...; kwargs...)
6767
end
6868

@@ -125,6 +125,14 @@ function init_cacheval(alg::Nothing, A, b, u, Pl, Pr, maxiters, abstol, reltol,
125125
A = A.A
126126
end
127127

128+
if applicable(ldiv!, A, u)
129+
alg = FunctionCall(ldiv!, (:A, :u))
130+
init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
131+
elseif applicable(ldiv!, u, A, b)
132+
alg = FunctionCall(ldiv!, (:u, :A, :b))
133+
init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
134+
end
135+
128136
# Special case on Arrays: avoid BLAS for RecursiveFactorization.jl when
129137
# it makes sense according to the benchmarks, which is dependent on
130138
# whether MKL or OpenBLAS is being used

src/function_call.jl

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,23 @@
1-
struct FunctionCall{F,A,K} <: SciMLLinearSolveAlgorithm
2-
func::F
3-
args::A
4-
kwargs::K
1+
struct FunctionCall{F,A} <: SciMLLinearSolveAlgorithm
2+
func!::F
3+
argsyms::A
54

6-
function FunctionCall(func::Function, args::Tuple; kwargs...)
7-
@assert iscallable(func)
8-
@assert applicable(func, args; kwargs)
9-
10-
new{typeof(func), typeof(args), typeof(kwargs)}(func, args, kwargs)
5+
function FunctionCall(func!::Function, argsyms::Tuple)
6+
new{typeof(func!), typeof(argsyms)}(func!, argsyms)
117
end
128
end
139

14-
function init_cacheval(alg::FunctionCall, cache::LinearCache)
15-
cache.cacheval
10+
function (f::FunctionCall)(cache::LinearCache)
11+
@unpack func!, argsyms = f
12+
args = [getproperty(cache,argsym) for argsym in argsyms]
13+
func!(args...)
1614
end
1715

1816
function SciMLBase.solve(cache::LinearCache, alg::FunctionCall,
1917
args...; kwargs...)
18+
@unpack u, b = cache
19+
copy!(u, b)
20+
alg(cache)
2021

21-
return
22+
return SciMLBase.build_linear_solution(alg,cache.u,nothing,cache)
2223
end

test/basictests.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,11 @@ end
4444
y = solve(prob1)
4545
@test A1 * y b1
4646

47+
_prob = LinearProblem(Diagonal(A1), b1; u0=x1)
48+
y = solve(_prob)
49+
@test A1 * y b1
50+
51+
#=
4752
_prob = LinearProblem(SymTridiagonal(A1), b1; u0=x1)
4853
y = solve(_prob)
4954
@test A1 * y ≈ b1
@@ -64,8 +69,10 @@ end
6469
_prob = LinearProblem(sparse(A1), b1; u0=x1)
6570
y = solve(_prob)
6671
@test A1 * y ≈ b1
72+
=#
6773
end
6874

75+
#=
6976
@testset "UMFPACK Factorization" begin
7077
A1 = A/1; b1 = rand(n); x1 = zero(b)
7178
A2 = A/2; b2 = rand(n); x2 = zero(b)
@@ -282,5 +289,6 @@ end
282289
@test sol13.u ≈ sol23.u
283290
@test sol13.u ≈ sol33.u
284291
end
292+
=#
285293

286294
end # testset

0 commit comments

Comments
 (0)