-
-
Notifications
You must be signed in to change notification settings - Fork 72
Overloads for LinearProblems with ForwardDiff Dual numbers #621
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 32 commits
Commits
Show all changes
34 commits
Select commit
Hold shift + click to select a range
c57c9d8
add LinearSolveForwardDiffExt.jl
jClugstor 677570f
add partial linsolve
jClugstor c154d25
fix up the linear dual solution
jClugstor fc8c4b5
add ForwardDiffExt to project
jClugstor c419c48
use real solve
jClugstor d6bddf9
add ForwardDiff as weakdep
jClugstor 81030fa
add imports and fix partial_val
jClugstor d7c56a8
add test
jClugstor f706c8f
add tests to runtest
jClugstor f95b799
format
jClugstor 6352024
rm debug message
jClugstor e6cda65
use inits and caches
jClugstor 501f07d
rearrange
jClugstor 9aa8b19
format
jClugstor 277c4f8
bring in linalg, add tols to tests
jClugstor 313e286
make sure using nonmutated A
jClugstor 922f7ec
dual cache should have original A and b
jClugstor 3547ec7
rearrange, make sure that dualcache works
jClugstor 9cd4e19
reinit! not needed for now
jClugstor b2a4291
correct setproperty! for DualLinearCache
jClugstor 680aec6
add tests for updating cache
jClugstor f9cd2fe
enable dual u0
jClugstor 1b48666
use new_u0
jClugstor b39ce87
reuse primal cache for Dual computation
jClugstor e5761c8
redundant line
jClugstor 9b69358
make sure u0 is correct type
jClugstor f55639a
add tests for iterative and u0
jClugstor 51ce056
make sure that linearcache.b is reset after dual solve
jClugstor 3565d9b
fix test
jClugstor 6f33486
forward steproperty and getproperty more
jClugstor d05ad09
use correct u0
jClugstor fb0626f
add test for updating one of A or b
jClugstor 613e9aa
p can be Any
jClugstor d690f1f
use remake instead
jClugstor File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,241 @@ | ||
module LinearSolveForwardDiffExt | ||
|
||
using LinearSolve | ||
using LinearAlgebra | ||
using ForwardDiff | ||
using ForwardDiff: Dual, Partials | ||
using SciMLBase | ||
using RecursiveArrayTools | ||
|
||
const DualLinearProblem = LinearProblem{ | ||
<:Union{Number, <:AbstractArray, Nothing}, iip, | ||
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}, | ||
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}, | ||
<:Union{Number, <:AbstractArray, SciMLBase.NullParameters} | ||
} where {iip, T, V, P} | ||
|
||
const DualALinearProblem = LinearProblem{ | ||
<:Union{Number, <:AbstractArray, Nothing}, | ||
iip, | ||
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}, | ||
<:Union{Number, <:AbstractArray}, | ||
<:Union{Number, <:AbstractArray, SciMLBase.NullParameters} | ||
} where {iip, T, V, P} | ||
|
||
const DualBLinearProblem = LinearProblem{ | ||
<:Union{Number, <:AbstractArray, Nothing}, | ||
iip, | ||
<:Union{Number, <:AbstractArray}, | ||
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}, | ||
<:Union{Number, <:AbstractArray, SciMLBase.NullParameters} | ||
} where {iip, T, V, P} | ||
|
||
const DualAbstractLinearProblem = Union{ | ||
DualLinearProblem, DualALinearProblem, DualBLinearProblem} | ||
|
||
LinearSolve.@concrete mutable struct DualLinearCache | ||
linear_cache | ||
dual_type | ||
partials_A | ||
partials_b | ||
end | ||
|
||
function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwargs...) | ||
# Solve the primal problem | ||
dual_u0 = copy(cache.linear_cache.u) | ||
sol = solve!(cache.linear_cache, alg, args...; kwargs...) | ||
primal_b = copy(cache.linear_cache.b) | ||
uu = sol.u | ||
|
||
primal_sol = deepcopy(sol) | ||
|
||
# Solves Dual partials separately | ||
∂_A = cache.partials_A | ||
∂_b = cache.partials_b | ||
|
||
rhs_list = xp_linsolve_rhs(uu, ∂_A, ∂_b) | ||
|
||
partial_cache = cache.linear_cache | ||
partial_cache.u = dual_u0 | ||
|
||
for i in eachindex(rhs_list) | ||
partial_cache.b = rhs_list[i] | ||
rhs_list[i] = copy(solve!(partial_cache, alg, args...; kwargs...).u) | ||
end | ||
|
||
# Reset to the original `b`, users will expect that `b` doesn't change if they don't tell it to | ||
partial_cache.b = primal_b | ||
|
||
partial_sols = rhs_list | ||
|
||
primal_sol, partial_sols | ||
end | ||
|
||
function xp_linsolve_rhs(uu, ∂_A::Union{<:Partials, <:AbstractArray{<:Partials}}, | ||
∂_b::Union{<:Partials, <:AbstractArray{<:Partials}}) | ||
A_list = partials_to_list(∂_A) | ||
b_list = partials_to_list(∂_b) | ||
|
||
Auu = [A * uu for A in A_list] | ||
|
||
return b_list .- Auu | ||
end | ||
|
||
function xp_linsolve_rhs( | ||
uu, ∂_A::Union{<:Partials, <:AbstractArray{<:Partials}}, ∂_b::Nothing) | ||
A_list = partials_to_list(∂_A) | ||
|
||
Auu = [A * uu for A in A_list] | ||
|
||
return -Auu | ||
end | ||
|
||
function xp_linsolve_rhs( | ||
uu, ∂_A::Nothing, ∂_b::Union{<:Partials, <:AbstractArray{<:Partials}}) | ||
b_list = partials_to_list(∂_b) | ||
b_list | ||
end | ||
|
||
function SciMLBase.solve(prob::DualAbstractLinearProblem, args...; kwargs...) | ||
return solve(prob, nothing, args...; kwargs...) | ||
end | ||
|
||
function SciMLBase.solve(prob::DualAbstractLinearProblem, ::Nothing, args...; | ||
assump = OperatorAssumptions(issquare(prob.A)), kwargs...) | ||
return solve(prob, LinearSolve.defaultalg(prob.A, prob.b, assump), args...; kwargs...) | ||
end | ||
|
||
function SciMLBase.solve(prob::DualAbstractLinearProblem, | ||
alg::LinearSolve.SciMLLinearSolveAlgorithm, args...; kwargs...) | ||
solve!(init(prob, alg, args...; kwargs...)) | ||
end | ||
|
||
function linearsolve_dual_solution( | ||
u::Number, partials, dual_type) | ||
return dual_type(u, partials) | ||
end | ||
|
||
function linearsolve_dual_solution( | ||
u::AbstractArray, partials, dual_type) | ||
partials_list = RecursiveArrayTools.VectorOfArray(partials) | ||
return map(((uᵢ, pᵢ),) -> dual_type(uᵢ, Partials(Tuple(pᵢ))), | ||
zip(u, partials_list[i, :] for i in 1:length(partials_list[1]))) | ||
end | ||
|
||
function SciMLBase.init( | ||
prob::DualAbstractLinearProblem, alg::LinearSolve.SciMLLinearSolveAlgorithm, | ||
args...; | ||
alias = LinearAliasSpecifier(), | ||
abstol = LinearSolve.default_tol(real(eltype(prob.b))), | ||
reltol = LinearSolve.default_tol(real(eltype(prob.b))), | ||
maxiters::Int = length(prob.b), | ||
verbose::Bool = false, | ||
Pl = nothing, | ||
Pr = nothing, | ||
assumptions = OperatorAssumptions(issquare(prob.A)), | ||
sensealg = LinearSolveAdjoint(), | ||
kwargs...) | ||
|
||
(; A, b, u0, p) = prob | ||
new_A = nodual_value(A) | ||
new_b = nodual_value(b) | ||
new_u0 = nodual_value(u0) | ||
|
||
∂_A = partial_vals(A) | ||
∂_b = partial_vals(b) | ||
|
||
primal_prob = LinearProblem(new_A, new_b, u0 = new_u0) | ||
#remake(prob; A = new_A, b = new_b, u0 = new_u0) | ||
|
||
if get_dual_type(prob.A) !== nothing | ||
dual_type = get_dual_type(prob.A) | ||
elseif get_dual_type(prob.b) !== nothing | ||
dual_type = get_dual_type(prob.b) | ||
end | ||
|
||
non_partial_cache = init( | ||
primal_prob, alg, args...; alias = alias, abstol = abstol, reltol = reltol, | ||
maxiters = maxiters, verbose = verbose, Pl = Pl, Pr = Pr, assumptions = assumptions, | ||
sensealg = sensealg, u0 = new_u0, kwargs...) | ||
return DualLinearCache(non_partial_cache, dual_type, ∂_A, ∂_b) | ||
end | ||
|
||
function SciMLBase.solve!(cache::DualLinearCache, args...; kwargs...) | ||
sol, | ||
partials = linearsolve_forwarddiff_solve( | ||
cache::DualLinearCache, cache.alg, args...; kwargs...) | ||
|
||
dual_sol = linearsolve_dual_solution(sol.u, partials, cache.dual_type) | ||
return SciMLBase.build_linear_solution( | ||
cache.alg, dual_sol, sol.resid, cache; sol.retcode, sol.iters, sol.stats | ||
) | ||
end | ||
|
||
# If setting A or b for DualLinearCache, put the Dual-stripped versions in the LinearCache | ||
# Also "forwards" setproperty so that | ||
function Base.setproperty!(dc::DualLinearCache, sym::Symbol, val) | ||
# If the property is A or b, also update it in the LinearCache | ||
if sym === :A || sym === :b || sym === :u | ||
setproperty!(dc.linear_cache, sym, nodual_value(val)) | ||
elseif hasfield(LinearSolve.LinearCache, sym) | ||
setproperty!(dc.linear_cache, sym, val) | ||
end | ||
|
||
# Update the partials if setting A or b | ||
if sym === :A | ||
setfield!(dc, :partials_A, partial_vals(val)) | ||
elseif sym === :b | ||
setfield!(dc, :partials_b, partial_vals(val)) | ||
else | ||
setfield!(dc, sym, val) | ||
end | ||
end | ||
|
||
# "Forwards" getproperty to LinearCache if necessary | ||
function Base.getproperty(dc::DualLinearCache, sym::Symbol) | ||
if hasfield(LinearSolve.LinearCache, sym) | ||
return getproperty(dc.linear_cache, sym) | ||
else | ||
return getfield(dc, sym) | ||
end | ||
end | ||
|
||
|
||
|
||
# Helper functions for Dual numbers | ||
get_dual_type(x::Dual) = typeof(x) | ||
get_dual_type(x::AbstractArray{<:Dual}) = eltype(x) | ||
get_dual_type(x) = nothing | ||
|
||
partial_vals(x::Dual) = ForwardDiff.partials(x) | ||
partial_vals(x::AbstractArray{<:Dual}) = map(ForwardDiff.partials, x) | ||
partial_vals(x) = nothing | ||
|
||
nodual_value(x) = x | ||
nodual_value(x::Dual) = ForwardDiff.value(x) | ||
nodual_value(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x) | ||
|
||
|
||
function partials_to_list(partial_matrix::Vector) | ||
p = eachindex(first(partial_matrix)) | ||
[[partial[i] for partial in partial_matrix] for i in p] | ||
end | ||
|
||
function partials_to_list(partial_matrix) | ||
p = length(first(partial_matrix)) | ||
m, n = size(partial_matrix) | ||
res_list = fill(zeros(m, n), p) | ||
for k in 1:p | ||
res = zeros(m, n) | ||
for i in 1:m | ||
for j in 1:n | ||
res[i, j] = partial_matrix[i, j][k] | ||
end | ||
end | ||
res_list[k] = res | ||
end | ||
return res_list | ||
end | ||
|
||
|
||
end |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
using LinearSolve | ||
using ForwardDiff | ||
using Test | ||
|
||
function h(p) | ||
(A = [p[1] p[2]+1 p[2]^3; | ||
3*p[1] p[1]+5 p[2] * p[1]-4; | ||
p[2]^2 9*p[1] p[2]], | ||
b = [p[1] + 1, p[2] * 2, p[1]^2]) | ||
end | ||
|
||
A, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)]) | ||
|
||
prob = LinearProblem(A, b) | ||
overload_x_p = solve(prob) | ||
backslash_x_p = A \ b | ||
krylov_overload_x_p = solve(prob, KrylovJL_GMRES()) | ||
@test ≈(overload_x_p, backslash_x_p, rtol = 1e-9) | ||
@test ≈(krylov_overload_x_p, backslash_x_p, rtol = 1e-9) | ||
|
||
krylov_prob = LinearProblem(A, b, u0 = rand(3)) | ||
krylov_u0_sol = solve(krylov_prob, KrylovJL_GMRES()) | ||
|
||
@test ≈(krylov_u0_sol, backslash_x_p, rtol = 1e-9) | ||
|
||
|
||
A, _ = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)]) | ||
backslash_x_p = A \ [6.0, 10.0, 25.0] | ||
prob = LinearProblem(A, [6.0, 10.0, 25.0]) | ||
|
||
@test ≈(solve(prob).u, backslash_x_p, rtol = 1e-9) | ||
@test ≈(solve(prob, KrylovJL_GMRES()).u, backslash_x_p, rtol = 1e-9) | ||
|
||
_, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)]) | ||
A = [5.0 6.0 125.0; 15.0 10.0 21.0; 25.0 45.0 5.0] | ||
backslash_x_p = A \ b | ||
prob = LinearProblem(A, b) | ||
|
||
@test ≈(solve(prob).u, backslash_x_p, rtol = 1e-9) | ||
@test ≈(solve(prob, KrylovJL_GMRES()).u, backslash_x_p, rtol = 1e-9) | ||
|
||
A, b = h([ForwardDiff.Dual(10.0, 1.0, 0.0), ForwardDiff.Dual(10.0, 0.0, 1.0)]) | ||
|
||
prob = LinearProblem(A, b) | ||
cache = init(prob) | ||
|
||
new_A, new_b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)]) | ||
cache.A = new_A | ||
cache.b = new_b | ||
|
||
x_p = solve!(cache) | ||
backslash_x_p = new_A \ new_b | ||
|
||
@test ≈(x_p, backslash_x_p, rtol = 1e-9) | ||
|
||
# Just update A | ||
A, b = h([ForwardDiff.Dual(10.0, 1.0, 0.0), ForwardDiff.Dual(10.0, 0.0, 1.0)]) | ||
|
||
prob = LinearProblem(A, b) | ||
cache = init(prob) | ||
|
||
new_A, _ = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)]) | ||
cache.A = new_A | ||
|
||
x_p = solve!(cache) | ||
backslash_x_p = new_A \ b | ||
|
||
@test ≈(x_p, backslash_x_p, rtol = 1e-9) | ||
|
||
# Just update b | ||
A, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)]) | ||
|
||
prob = LinearProblem(A, b) | ||
cache = init(prob) | ||
|
||
_, new_b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)]) | ||
cache.b = new_b | ||
|
||
x_p = solve!(cache) | ||
backslash_x_p = A \ new_b | ||
|
||
@test ≈(x_p, backslash_x_p, rtol = 1e-9) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.