Skip to content

Commit 7f6b1e8

Browse files
committed
make opt out instead
1 parent de4685d commit 7f6b1e8

File tree

2 files changed

+17
-42
lines changed

2 files changed

+17
-42
lines changed

ext/LinearSolveForwardDiffExt.jl

Lines changed: 12 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,13 @@
11
module LinearSolveForwardDiffExt
22

33
using LinearSolve
4-
using LinearSolve: SciMLLinearSolveAlgorithm
4+
using LinearSolve: SciMLLinearSolveAlgorithm, __init
55
using LinearAlgebra
66
using ForwardDiff
77
using ForwardDiff: Dual, Partials
88
using SciMLBase
99
using RecursiveArrayTools
1010

11-
using LinearSolve: LUFactorization,
12-
QRFactorization,
13-
DiagonalFactorization,
14-
DirectLdiv!,
15-
SparspakFactorization,
16-
KLUFactorization,
17-
UMFPACKFactorization,
18-
KrylovJL,
19-
RFLUFactorization,
20-
LDLtFactorization,
21-
BunchKaufmanFactorization,
22-
CHOLMODFactorization,
23-
SVDFactorization,
24-
CholeskyFactorization,
25-
NormalCholeskyFactorization,
26-
AppleAccelerateLUFactorization,
27-
MKLLUFactorization,
28-
DefaultLinearSolver
29-
3011
const DualLinearProblem = LinearProblem{
3112
<:Union{Number, <:AbstractArray, Nothing}, iip,
3213
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}},
@@ -53,25 +34,6 @@ const DualBLinearProblem = LinearProblem{
5334
const DualAbstractLinearProblem = Union{
5435
DualLinearProblem, DualALinearProblem, DualBLinearProblem}
5536

56-
const acceptable_algs = Union{LUFactorization,
57-
QRFactorization,
58-
DiagonalFactorization,
59-
DirectLdiv!,
60-
SparspakFactorization,
61-
KLUFactorization,
62-
UMFPACKFactorization,
63-
KrylovJL,
64-
RFLUFactorization,
65-
LDLtFactorization,
66-
BunchKaufmanFactorization,
67-
CHOLMODFactorization,
68-
SVDFactorization,
69-
CholeskyFactorization,
70-
NormalCholeskyFactorization,
71-
AppleAccelerateLUFactorization,
72-
MKLLUFactorization,
73-
DefaultLinearSolver}
74-
7537
LinearSolve.@concrete mutable struct DualLinearCache
7638
linear_cache
7739
dual_type
@@ -159,8 +121,17 @@ function linearsolve_dual_solution(u::AbstractArray, partials,
159121
zip(u, partials_list[i, :] for i in 1:length(partials_list.u[1])))
160122
end
161123

162-
function SciMLBase.init(
163-
prob::DualAbstractLinearProblem, alg::acceptable_algs,
124+
function SciMLBase.init(prob::DualAbstractLinearProblem, alg::SciMLLinearSolveAlgorithm, args...; kwargs...)
125+
return __dual_init(prob, alg, args...; kwargs...)
126+
end
127+
128+
# Opt out for GenericLUFactorization
129+
function SciMLBase.init(prob::DualAbstractLinearProblem, alg::GenericLUFactorization, args...; kwargs...)
130+
return LinearSolve.__init(prob,alg, args...; kwargs...)
131+
end
132+
133+
function __dual_init(
134+
prob::DualAbstractLinearProblem, alg::SciMLLinearSolveAlgorithm,
164135
args...;
165136
alias = LinearAliasSpecifier(),
166137
abstol = LinearSolve.default_tol(real(eltype(prob.b))),

src/common.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,11 @@ function __init_u0_from_Ab(A, b)
137137
end
138138
__init_u0_from_Ab(::SMatrix{S1, S2}, b) where {S1, S2} = zeros(SVector{S2, eltype(b)})
139139

140-
function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,
140+
function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm, args...; kwargs...)
141+
__init(prob, alg, args..., kwargs...)
142+
end
143+
144+
function __init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,
141145
args...;
142146
alias = LinearAliasSpecifier(),
143147
abstol = default_tol(real(eltype(prob.b))),

0 commit comments

Comments
 (0)