Skip to content

Commit caa1bee

Browse files
committed
functioncall struct
1 parent 1a8261b commit caa1bee

File tree

2 files changed

+31
-1
lines changed

2 files changed

+31
-1
lines changed

src/default.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,11 +53,19 @@ end
5353

5454
function SciMLBase.solve(cache::LinearCache, alg::Nothing,
5555
args...; kwargs...)
56-
@unpack A = cache
56+
@unpack A, b, u = cache
5757
if A isa DiffEqArrayOperator
5858
A = A.A
5959
end
6060

61+
if applicable(ldiv!, A, u)
62+
alg = FunctionCall(ldiv!, (A, u))
63+
SciMLBase.solve(cache, alg, args...; kwargs...)
64+
elseif applicable(ldiv!, u, A, b)
65+
alg = FunctionCall(ldiv!, (u, A, b))
66+
SciMLBase.solve(cache, alg, args...; kwargs...)
67+
end
68+
6169
# Special case on Arrays: avoid BLAS for RecursiveFactorization.jl when
6270
# it makes sense according to the benchmarks, which is dependent on
6371
# whether MKL or OpenBLAS is being used

src/function_call.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
struct FunctionCall{F,A,K} <: SciMLLinearSolveAlgorithm
2+
func::F
3+
args::A
4+
kwargs::K
5+
6+
function FunctionCall(func::Function, args::Tuple; kwargs...)
7+
@assert iscallable(func)
8+
@assert applicable(func, args; kwargs)
9+
10+
new{typeof(func), typeof(args), typeof(kwargs)}(func, args, kwargs)
11+
end
12+
end
13+
14+
function init_cacheval(alg::FunctionCall, cache::LinearCache)
15+
cache.cacheval
16+
end
17+
18+
function SciMLBase.solve(cache::LinearCache, alg::FunctionCall,
19+
args...; kwargs...)
20+
21+
return
22+
end

0 commit comments

Comments
 (0)