Skip to content

Commit c960f8f

Browse files
Merge pull request #112 from vpuri3/vp-ldiv
custom linear solve function
2 parents 5a75de5 + b692761 commit c960f8f

File tree

7 files changed

+118
-2
lines changed

7 files changed

+118
-2
lines changed

docs/make.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ makedocs(
2929
],
3030
"Advanced" => Any[
3131
"advanced/developing.md"
32+
"advanced/custom.md"
3233
]
3334
]
3435
)

docs/src/advanced/custom.md

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
# Passing in a Custom Linear Solver
2+
Julia users are building a wide variety of applications in the SciML ecosystem,
3+
often requiring problem-specific handling of their linear solves. As existing solvers in `LinearSolve.jl` may not
4+
be optimally suited for novel applications, it is essential for the linear solve
5+
interface to be easily extendable by users. To that end, the linear solve algorithm
6+
`LinearSolveFunction()` accepts a user-defined function for handling the solve. A
7+
user can pass in their custom linear solve function, say `my_linsolve`, to
8+
`LinearSolveFunction()`. A contrived example of solving a linear system with a custom solver is below.
9+
```julia
10+
using LinearSolve, LinearAlgebra
11+
12+
function my_linsolve(A,b,u,p,newA,Pl,Pr,solverdata;verbose=true, kwargs...)
13+
if verbose == true
14+
println("solving Ax=b")
15+
end
16+
u = A \ b
17+
return u
18+
end
19+
20+
prob = LinearProblem(Diagonal(rand(4)), rand(4))
21+
alg = LinearSolveFunction(my_linsolve),
22+
sol = solve(prob, alg)
23+
```
24+
The inputs to the function are as follows:
25+
- `A`, the linear operator
26+
- `b`, the right-hand-side
27+
- `u`, the solution initialized as `zero(b)`,
28+
- `p`, a set of parameters
29+
- `newA`, a `Bool` which is `true` if `A` has been modified since last solve
30+
- `Pl`, left-preconditioner
31+
- `Pr`, right-preconditioner
32+
- `solverdata`, solver cache set to `nothing` if solver hasn't been initialized
33+
- `kwargs`, standard SciML keyword arguments such as `verbose`, `maxiters`,
34+
`abstol`, `reltol`
35+
The function `my_linsolve` must accept the above specified arguments, and return
36+
the solution, `u`. As memory for `u` is already allocated, the user may choose
37+
to modify `u` in place as follows:
38+
```julia
39+
function my_linsolve!(A,b,u,p,newA,Pl,Pr,solverdata;verbose=true, kwargs...)
40+
if verbose == true
41+
println("solving Ax=b")
42+
end
43+
u .= A \ b # in place
44+
return u
45+
end
46+
47+
alg = LinearSolveFunction(my_linsolve!)
48+
sol = solve(prob, alg)
49+
```
50+
Finally, note that `LinearSolveFunction()` dispatches to the default linear solve
51+
algorithm handling if no arguments are passed in.
52+
```julia
53+
alg = LinearSolveFunction()
54+
sol = solve(prob, alg) # same as solve(prob, nothing)
55+
```

docs/src/solvers/solvers.md

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,17 @@ Pardiso.jl's methods are also known to be very efficient sparse linear solvers.
1919

2020
As sparse matrices get larger, iterative solvers tend to get more efficient than
2121
factorization methods if a lower tolerance of the solution is required.
22-
Krylov.jl works with CPUs and GPUs and tends to be more efficient than other
22+
23+
IterativeSolvers.jl uses a low-rank Q update in its GMRES so it tends to be
24+
faster than Krylov.jl for CPU-based arrays, but it's only compatible with
25+
CPU-based arrays while Krylov.jl is more general and will support accelerators
26+
like CUDA. Krylov.jl works with CPUs and GPUs and tends to be more efficient than other
2327
Krylov-based methods.
2428

29+
Finally, a user can pass a custom function ofr the linear solve using
30+
`LinearSolveFunction()` if existing solvers are not optimal for their application.
31+
The interface is detailed [here](#passing-in-a-custom-linear-solver)
32+
2533
## Full List of Methods
2634

2735
### RecursiveFactorization.jl

src/LinearSolve.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,13 @@ using Reexport
2626
abstract type SciMLLinearSolveAlgorithm <: SciMLBase.AbstractLinearAlgorithm end
2727
abstract type AbstractFactorization <: SciMLLinearSolveAlgorithm end
2828
abstract type AbstractKrylovSubspaceMethod <: SciMLLinearSolveAlgorithm end
29+
abstract type AbstractSolveFunction <: SciMLLinearSolveAlgorithm end
2930

3031
# Traits
3132

3233
needs_concrete_A(alg::AbstractFactorization) = true
3334
needs_concrete_A(alg::AbstractKrylovSubspaceMethod) = false
35+
needs_concrete_A(alg::AbstractSolveFunction) = false
3436

3537
# Code
3638

@@ -39,6 +41,7 @@ include("factorization.jl")
3941
include("simplelu.jl")
4042
include("iterative_wrappers.jl")
4143
include("preconditioners.jl")
44+
include("solve_function.jl")
4245
include("default.jl")
4346
include("init.jl")
4447

@@ -48,6 +51,9 @@ isopenblas() = IS_OPENBLAS[]
4851
export LUFactorization, SVDFactorization, QRFactorization, GenericFactorization,
4952
GenericLUFactorization, SimpleLUFactorization, RFLUFactorization,
5053
UMFPACKFactorization, KLUFactorization
54+
55+
export LinearSolveFunction
56+
5157
export KrylovJL, KrylovJL_CG, KrylovJL_GMRES, KrylovJL_BICGSTAB, KrylovJL_MINRES,
5258
IterativeSolversJL, IterativeSolversJL_CG, IterativeSolversJL_GMRES,
5359
IterativeSolversJL_BICGSTAB, IterativeSolversJL_MINRES,

src/common.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ function set_cacheval(cache::LinearCache, alg_cache)
6565
return cache
6666
end
6767

68-
init_cacheval(alg::SciMLLinearSolveAlgorithm, A, b, u) = nothing
68+
init_cacheval(alg::SciMLLinearSolveAlgorithm, args...) = nothing
6969

7070
SciMLBase.init(prob::LinearProblem, args...; kwargs...) = SciMLBase.init(prob,nothing,args...;kwargs...)
7171

src/solve_function.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
#
2+
struct LinearSolveFunction{F} <: AbstractSolveFunction
3+
solve_func::F
4+
end
5+
6+
function SciMLBase.solve(cache::LinearCache, alg::LinearSolveFunction,
7+
args...; kwargs...)
8+
@unpack A,b,u,p,isfresh,Pl,Pr,cacheval = cache
9+
@unpack solve_func = alg
10+
11+
u = solve_func(A,b,u,p,isfresh,Pl,Pr,cacheval;kwargs...)
12+
cache = set_u(cache, u)
13+
14+
return SciMLBase.build_linear_solution(alg,cache.u,nothing,cache)
15+
end

test/basictests.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,4 +287,35 @@ end
287287
@test sol13.u sol33.u
288288
end
289289

290+
@testset "Solve Function" begin
291+
292+
A1 = rand(n) |> Diagonal; b1 = rand(n); x1 = zero(b1)
293+
A2 = rand(n) |> Diagonal; b2 = rand(n); x2 = zero(b1)
294+
295+
function sol_func(A,b,u,p,newA,Pl,Pr,solverdata;verbose=true, kwargs...)
296+
if verbose == true
297+
println("out-of-place solve")
298+
end
299+
u = A \ b
300+
end
301+
302+
function sol_func!(A,b,u,p,newA,Pl,Pr,solverdata;verbose=true, kwargs...)
303+
if verbose == true
304+
println("in-place solve")
305+
end
306+
ldiv!(u,A,b)
307+
end
308+
309+
prob1 = LinearProblem(A1, b1; u0=x1)
310+
prob2 = LinearProblem(A1, b1; u0=x1)
311+
312+
for alg in (
313+
LinearSolveFunction(sol_func),
314+
LinearSolveFunction(sol_func!),
315+
)
316+
317+
test_interface(alg, prob1, prob2)
318+
end
319+
end
320+
290321
end # testset

0 commit comments

Comments
 (0)