Skip to content

Commit 499e760

Browse files
authored
Make GenericLUFactorization bypass overloads when used on DualLinear Problems (#685)
* add list of algs acceptable for going through Dual overloads * make sure that GenericLUFactorization doesn't go through DualLinearCache * add test * fix test * make acceptable_algs in to const * make opt out instead * make sure kwargs are passed in * indirect import
1 parent 2590a25 commit 499e760

File tree

3 files changed

+24
-11
lines changed

3 files changed

+24
-11
lines changed

ext/LinearSolveForwardDiffExt.jl

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module LinearSolveForwardDiffExt
22

33
using LinearSolve
4-
using LinearSolve: SciMLLinearSolveAlgorithm
4+
using LinearSolve: SciMLLinearSolveAlgorithm, __init
55
using LinearAlgebra
66
using ForwardDiff
77
using ForwardDiff: Dual, Partials
@@ -121,8 +121,17 @@ function linearsolve_dual_solution(u::AbstractArray, partials,
121121
zip(u, partials_list[i, :] for i in 1:length(partials_list.u[1])))
122122
end
123123

124-
function SciMLBase.init(
125-
prob::DualAbstractLinearProblem, alg::LinearSolve.SciMLLinearSolveAlgorithm,
124+
function SciMLBase.init(prob::DualAbstractLinearProblem, alg::SciMLLinearSolveAlgorithm, args...; kwargs...)
125+
return __dual_init(prob, alg, args...; kwargs...)
126+
end
127+
128+
# Opt out for GenericLUFactorization
129+
function SciMLBase.init(prob::DualAbstractLinearProblem, alg::GenericLUFactorization, args...; kwargs...)
130+
return __init(prob,alg, args...; kwargs...)
131+
end
132+
133+
function __dual_init(
134+
prob::DualAbstractLinearProblem, alg::SciMLLinearSolveAlgorithm,
126135
args...;
127136
alias = LinearAliasSpecifier(),
128137
abstol = LinearSolve.default_tol(real(eltype(prob.b))),

src/common.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,11 @@ function __init_u0_from_Ab(A, b)
137137
end
138138
__init_u0_from_Ab(::SMatrix{S1, S2}, b) where {S1, S2} = zeros(SVector{S2, eltype(b)})
139139

140-
function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,
140+
function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm, args...; kwargs...)
141+
__init(prob, alg, args...; kwargs...)
142+
end
143+
144+
function __init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,
141145
args...;
142146
alias = LinearAliasSpecifier(),
143147
abstol = default_tol(real(eltype(prob.b))),

test/forwarddiff_overloads.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -87,13 +87,6 @@ backslash_x_p = A \ new_b
8787
@test (x_p, backslash_x_p, rtol = 1e-9)
8888

8989
# Nested Duals
90-
function h(p)
91-
(A = [p[1] p[2]+1 p[2]^3;
92-
3*p[1] p[1]+5 p[2] * p[1]-4;
93-
p[2]^2 9*p[1] p[2]],
94-
b = [p[1] + 1, p[2] * 2, p[1]^2])
95-
end
96-
9790
A,
9891
b = h([ForwardDiff.Dual(ForwardDiff.Dual(5.0, 1.0, 0.0), 1.0, 0.0),
9992
ForwardDiff.Dual(ForwardDiff.Dual(5.0, 1.0, 0.0), 0.0, 1.0)])
@@ -193,3 +186,10 @@ overload_x_p = solve(prob, UMFPACKFactorization())
193186
backslash_x_p = A \ b
194187

195188
@test (overload_x_p, backslash_x_p, rtol = 1e-9)
189+
190+
191+
# Test that GenericLU doesn't create a DualLinearCache
192+
A, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])
193+
194+
prob = LinearProblem(A, b)
195+
@test init(prob, GenericLUFactorization()) isa LinearSolve.LinearCache

0 commit comments

Comments
 (0)