Skip to content

Commit 6d302c2

Browse files
committed
add methods defaultalg to support scimlops
1 parent 6f3ae84 commit 6d302c2

File tree

3 files changed

+67
-25
lines changed

3 files changed

+67
-25
lines changed

src/common.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ end
2525
"""
2626
$(SIGNATURES)
2727
"""
28+
# TODO - this should modify OperatorAssumption??
2829
function set_A(cache::LinearCache, A)
2930
@set! cache.A = A
3031
@set! cache.isfresh = true
@@ -94,7 +95,7 @@ function SciMLBase.init(prob::LinearProblem, alg::Union{SciMLLinearSolveAlgorith
9495
verbose::Bool = false,
9596
Pl = IdentityOperator{size(prob.A, 1)}(),
9697
Pr = IdentityOperator{size(prob.A, 2)}(),
97-
assumptions = OperatorAssumptions(),
98+
assumptions = OperatorAssumptions(issquare(prob.A)),
9899
kwargs...)
99100
@unpack A, b, u0, p = prob
100101

src/default.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ end
3636
function defaultalg(A::Diagonal, b, ::OperatorAssumptions{true})
3737
DiagonalFactorization()
3838
end
39+
# TODO - Diagonal matrices are always square
3940
function defaultalg(A::Diagonal, b, ::OperatorAssumptions{false})
4041
DiagonalFactorization()
4142
end
@@ -76,13 +77,21 @@ function defaultalg(A, b::GPUArraysCore.AbstractGPUArray, ::OperatorAssumptions{
7677
end
7778

7879
function defaultalg(A::SciMLBase.AbstractSciMLOperator, b,
79-
assumptions::OperatorAssumptions)
80+
assumptions::OperatorAssumptions{true})
81+
if has_ldiv!(A)
82+
return DirectLdiv!()
83+
end
84+
8085
KrylovJL_GMRES()
8186
end
8287

8388
# Ambiguity handling
8489
function defaultalg(A::SciMLBase.AbstractSciMLOperator, b,
8590
assumptions::OperatorAssumptions{Nothing})
91+
if has_ldiv!(A)
92+
return DirectLdiv!()
93+
end
94+
8695
KrylovJL_GMRES()
8796
end
8897

test/basictests.jl

Lines changed: 55 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -27,20 +27,20 @@ function test_interface(alg, prob1, prob2)
2727
b2 = prob2.b
2828
x2 = prob2.u0
2929

30-
y = solve(prob1, alg; cache_kwargs...)
31-
@test A1 * y b1
30+
sol = solve(prob1, alg; cache_kwargs...)
31+
@test A1 * sol.u b1
3232

3333
cache = SciMLBase.init(prob1, alg; cache_kwargs...) # initialize cache
34-
y = solve(cache)
35-
@test A1 * y b1
34+
sol = solve(cache)
35+
@test A1 * sol.u b1
3636

37-
cache = LinearSolve.set_A(cache, copy(A2))
38-
y = solve(cache; cache_kwargs...)
39-
@test A2 * y b1
37+
cache = LinearSolve.set_A(cache, deepcopy(A2))
38+
sol = solve(cache; cache_kwargs...)
39+
@test A2 * sol.u b1
4040

4141
cache = LinearSolve.set_b(cache, b2)
42-
y = solve(cache; cache_kwargs...)
43-
@test A2 * y b2
42+
sol = solve(cache; cache_kwargs...)
43+
@test A2 * sol.u b2
4444

4545
return
4646
end
@@ -359,26 +359,58 @@ end
359359
b2 = rand(n)
360360
x2 = zero(b1)
361361

362-
function sol_func(A, b, u, p, newA, Pl, Pr, solverdata; verbose = true, kwargs...)
363-
if verbose == true
364-
println("out-of-place solve")
362+
@testset "LinearSolveFunction" begin
363+
function sol_func(A, b, u, p, newA, Pl, Pr, solverdata; verbose = true, kwargs...)
364+
if verbose == true
365+
println("out-of-place solve")
366+
end
367+
u = A \ b
365368
end
366-
u = A \ b
367-
end
368369

369-
function sol_func!(A, b, u, p, newA, Pl, Pr, solverdata; verbose = true, kwargs...)
370-
if verbose == true
371-
println("in-place solve")
370+
function sol_func!(A, b, u, p, newA, Pl, Pr, solverdata; verbose = true, kwargs...)
371+
if verbose == true
372+
println("in-place solve")
373+
end
374+
ldiv!(u, A, b)
375+
end
376+
377+
prob1 = LinearProblem(A1, b1; u0 = x1)
378+
prob2 = LinearProblem(A1, b1; u0 = x1)
379+
380+
for alg in (LinearSolveFunction(sol_func),
381+
LinearSolveFunction(sol_func!))
382+
test_interface(alg, prob1, prob2)
372383
end
373-
ldiv!(u, A, b)
374384
end
375385

376-
prob1 = LinearProblem(A1, b1; u0 = x1)
377-
prob2 = LinearProblem(A1, b1; u0 = x1)
386+
@testset "DirectLdiv!" begin
387+
function get_operator(A, u)
388+
F = lu(A)
378389

379-
for alg in (LinearSolveFunction(sol_func),
380-
LinearSolveFunction(sol_func!))
381-
test_interface(alg, prob1, prob2)
390+
function f(du, u, p, t)
391+
println("using FunctionOperator mul!")
392+
mul!(du, A, u)
393+
end
394+
395+
function fi(du, u, p, t)
396+
println("using FunctionOperator ldiv!")
397+
ldiv!(du, F, u)
398+
end
399+
400+
FunctionOperator(f, u, u; isinplace=true, op_inverse=fi)
401+
end
402+
403+
op1 = get_operator(A1, x1*0)
404+
op2 = get_operator(A2, x2*0)
405+
406+
prob1 = LinearProblem(op1, b1; u0 = x1)
407+
prob2 = LinearProblem(op2, b2; u0 = x2)
408+
409+
@test LinearSolve.defaultalg(op1, x1) isa DirectLdiv!
410+
@test LinearSolve.defaultalg(op2, x2) isa DirectLdiv!
411+
412+
test_interface(DirectLdiv!(), prob1, prob2)
413+
test_interface(nothing, prob1, prob2)
382414
end
383415
end
384416
end # testset

0 commit comments

Comments
 (0)