Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/src/solvers/solvers.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,12 @@ use `Krylov_GMRES()`.

## Full List of Methods

### Polyalgorithms

```@docs
LinearSolve.DefaultLinearSolver
```

### RecursiveFactorization.jl

!!! note
Expand Down
15 changes: 15 additions & 0 deletions src/LinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,23 @@ EnumX.@enumx DefaultAlgorithmChoice begin
KrylovJL_LSMR
end

"""
DefaultLinearSolver(;safetyfallback=true)

The default linear solver. This is the algorithm chosen when `solve(prob)`
is called. It's a polyalgorithm that detects the optimal method for a given
`A, b` and hardware (Intel, AMD, GPU, etc.).

## Keyword Arguments

* `safetyfallback`: determines whether to fallback to a column-pivoted QR factorization
when an LU factorization fails. This can be required if `A` is rank-deficient. Defaults
to true.
Comment on lines +131 to +133
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
* `safetyfallback`: determines whether to fallback to a column-pivoted QR factorization
when an LU factorization fails. This can be required if `A` is rank-deficient. Defaults
to true.
- `safetyfallback`: determines whether to fallback to a column-pivoted QR factorization
when an LU factorization fails. This can be required if `A` is rank-deficient. Defaults
to true.

"""
struct DefaultLinearSolver <: SciMLLinearSolveAlgorithm
alg::DefaultAlgorithmChoice.T
safetyfallback::Bool
DefaultLinearSolver(alg; safetyfallback=true) = new(alg,safetyfallback)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
DefaultLinearSolver(alg; safetyfallback=true) = new(alg,safetyfallback)
DefaultLinearSolver(alg; safetyfallback = true) = new(alg, safetyfallback)

end

const BLASELTYPES = Union{Float32, Float64, ComplexF32, ComplexF64}
Expand Down
36 changes: 31 additions & 5 deletions src/default.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ end
ex = Expr(:if, ex.args...)
end

# Handle special case of Column-pivoted QR fallback for LU
function __setfield!(cache::DefaultLinearSolverInit, alg::DefaultLinearSolver, v::LinearAlgebra.QRPivoted)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
function __setfield!(cache::DefaultLinearSolverInit, alg::DefaultLinearSolver, v::LinearAlgebra.QRPivoted)
function __setfield!(cache::DefaultLinearSolverInit,
alg::DefaultLinearSolver, v::LinearAlgebra.QRPivoted)

setfield!(cache, :QRFactorizationPivoted, v)
end

# Legacy fallback
# For SciML algorithms already using `defaultalg`, all assume square matrix.
defaultalg(A, b) = defaultalg(A, b, OperatorAssumptions(true))
Expand Down Expand Up @@ -352,11 +357,32 @@ end
kwargs...)
ex = :()
for alg in first.(EnumX.symbol_map(DefaultAlgorithmChoice.T))
newex = quote
sol = SciMLBase.solve!(cache, $(algchoice_to_alg(alg)), args...; kwargs...)
SciMLBase.build_linear_solution(alg, sol.u, sol.resid, sol.cache;
retcode = sol.retcode,
iters = sol.iters, stats = sol.stats)
if alg in Symbol.((DefaultAlgorithmChoice.LUFactorization,
DefaultAlgorithmChoice.RFLUFactorization,
DefaultAlgorithmChoice.MKLLUFactorization,
DefaultAlgorithmChoice.AppleAccelerateLUFactorization,
DefaultAlgorithmChoice.GenericLUFactorization))
newex = quote
sol = SciMLBase.solve!(cache, $(algchoice_to_alg(alg)), args...; kwargs...)
if sol.retcode === ReturnCode.Failure && alg.safetyfallback
## TODO: Add verbosity logging here about using the fallback
sol = SciMLBase.solve!(cache, QRFactorization(ColumnNorm()), args...; kwargs...)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
sol = SciMLBase.solve!(cache, QRFactorization(ColumnNorm()), args...; kwargs...)
sol = SciMLBase.solve!(
cache, QRFactorization(ColumnNorm()), args...; kwargs...)

SciMLBase.build_linear_solution(alg, sol.u, sol.resid, sol.cache;
retcode = sol.retcode,
iters = sol.iters, stats = sol.stats)
else
SciMLBase.build_linear_solution(alg, sol.u, sol.resid, sol.cache;
retcode = sol.retcode,
iters = sol.iters, stats = sol.stats)
end
end
else
newex = quote
sol = SciMLBase.solve!(cache, $(algchoice_to_alg(alg)), args...; kwargs...)
SciMLBase.build_linear_solution(alg, sol.u, sol.resid, sol.cache;
retcode = sol.retcode,
iters = sol.iters, stats = sol.stats)
end
end
alg_enum = getproperty(LinearSolve.DefaultAlgorithmChoice, alg)
ex = if ex == :()
Expand Down
16 changes: 15 additions & 1 deletion test/default_algs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -158,4 +158,18 @@ prob = LinearProblem(A, b)
@test_broken SciMLBase.successful_retcode(solve(prob))

prob2 = LinearProblem(A2, b)
@test SciMLBase.successful_retcode(solve(prob2))
@test SciMLBase.successful_retcode(solve(prob2))

# Column-Pivoted QR fallback on failed LU
A = [1.0 0 0 0
0 1 0 0
0 0 1 0
Comment on lines +165 to +167
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
0 1 0 0
0 0 1 0
0 0 0 0]
0 1 0 0
0 0 1 0
0 0 0 0]

0 0 0 0]
b = rand(4)
prob = LinearProblem(A, b)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
sol = solve(prob, LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.LUFactorization; safetyfallback=false))
sol = solve(prob,
LinearSolve.DefaultLinearSolver(
LinearSolve.DefaultAlgorithmChoice.LUFactorization; safetyfallback = false))

sol = solve(prob, LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.LUFactorization; safetyfallback=false))
@test sol.retcode === ReturnCode.Failure
@test sol.u == zeros(4)

sol = solve(prob)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
@test sol.u svd(A)\b
@test sol.u svd(A) \ b

@test sol.u ≈ svd(A)\b
Loading