11module LinearSolveForwardDiffExt
22
33using LinearSolve
4- using LinearSolve: SciMLLinearSolveAlgorithm
4+ using LinearSolve: SciMLLinearSolveAlgorithm, __init
55using LinearAlgebra
66using ForwardDiff
77using ForwardDiff: Dual, Partials
88using SciMLBase
99using RecursiveArrayTools
1010
11- using LinearSolve: LUFactorization,
12- QRFactorization,
13- DiagonalFactorization,
14- DirectLdiv!,
15- SparspakFactorization,
16- KLUFactorization,
17- UMFPACKFactorization,
18- KrylovJL,
19- RFLUFactorization,
20- LDLtFactorization,
21- BunchKaufmanFactorization,
22- CHOLMODFactorization,
23- SVDFactorization,
24- CholeskyFactorization,
25- NormalCholeskyFactorization,
26- AppleAccelerateLUFactorization,
27- MKLLUFactorization,
28- DefaultLinearSolver
29-
3011const DualLinearProblem = LinearProblem{
3112 <: Union{Number, <:AbstractArray, Nothing} , iip,
3213 <: Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}} ,
@@ -53,25 +34,6 @@ const DualBLinearProblem = LinearProblem{
5334const DualAbstractLinearProblem = Union{
5435 DualLinearProblem, DualALinearProblem, DualBLinearProblem}
5536
56- const acceptable_algs = Union{LUFactorization,
57- QRFactorization,
58- DiagonalFactorization,
59- DirectLdiv!,
60- SparspakFactorization,
61- KLUFactorization,
62- UMFPACKFactorization,
63- KrylovJL,
64- RFLUFactorization,
65- LDLtFactorization,
66- BunchKaufmanFactorization,
67- CHOLMODFactorization,
68- SVDFactorization,
69- CholeskyFactorization,
70- NormalCholeskyFactorization,
71- AppleAccelerateLUFactorization,
72- MKLLUFactorization,
73- DefaultLinearSolver}
74-
7537LinearSolve. @concrete mutable struct DualLinearCache
7638 linear_cache
7739 dual_type
@@ -159,8 +121,17 @@ function linearsolve_dual_solution(u::AbstractArray, partials,
159121 zip (u, partials_list[i, :] for i in 1 : length (partials_list. u[1 ])))
160122end
161123
162- function SciMLBase. init (
163- prob:: DualAbstractLinearProblem , alg:: acceptable_algs ,
124+ function SciMLBase. init (prob:: DualAbstractLinearProblem , alg:: SciMLLinearSolveAlgorithm , args... ; kwargs... )
125+ return __dual_init (prob, alg, args... ; kwargs... )
126+ end
127+
128+ # Opt out for GenericLUFactorization
129+ function SciMLBase. init (prob:: DualAbstractLinearProblem , alg:: GenericLUFactorization , args... ; kwargs... )
130+ return LinearSolve. __init (prob,alg, args... ; kwargs... )
131+ end
132+
133+ function __dual_init (
134+ prob:: DualAbstractLinearProblem , alg:: SciMLLinearSolveAlgorithm ,
164135 args... ;
165136 alias = LinearAliasSpecifier (),
166137 abstol = LinearSolve. default_tol (real (eltype (prob. b))),
0 commit comments