Skip to content

Commit c154d25

Browse files
committed
fix up the linear dual solution
1 parent 677570f commit c154d25

File tree

1 file changed

+60
-25
lines changed

1 file changed

+60
-25
lines changed

ext/LinearSolveForwardDiffExt.jl

Lines changed: 60 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,56 +1,91 @@
11
module LinearSolveForwardDiffExt
22

33
const DualLinearProblem = LinearProblem{
4-
<:Union{Number, <:AbstractArray}, iip,
5-
<:Union{<:Dual{T,V,P}, <:AbstractArray{<:Dual{T,V,P}}},
6-
<:Union{<:Dual{T,V,P}, <:AbstractArray{<:Dual{T,V,P}}},
7-
<:Union{Number, <:AbstractArray}
8-
} where {iip, T, V}
4+
<:Union{Number,<:AbstractArray, Nothing},iip,
5+
<:Union{<:Dual{T,V,P},<:AbstractArray{<:Dual{T,V,P}}},
6+
<:Union{<:Dual{T,V,P},<:AbstractArray{<:Dual{T,V,P}}},
7+
<:Union{Number,<:AbstractArray, SciMLBase.NullParameters}
8+
} where {iip, T, V, P}
99

1010

1111
const DualALinearProblem = LinearProblem{
12-
<:Union{Number, <:AbstractArray},
12+
<:Union{Number,<:AbstractArray, Nothing},
1313
iip,
14-
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}},
15-
<:Union{Number, <:AbstractArray},
16-
<:Union{Number, <:AbstractArray}
17-
}
14+
<:Union{<:Dual{T,V,P},<:AbstractArray{<:Dual{T,V,P}}},
15+
<:Union{Number,<:AbstractArray},
16+
<:Union{Number,<:AbstractArray, SciMLBase.NullParameters}
17+
} where {iip, T, V, P}
1818

1919
const DualBLinearProblem = LinearProblem{
20-
<:Union{Number, <:AbstractArray},
20+
<:Union{Number,<:AbstractArray, Nothing},
2121
iip,
22-
<:Union{Number, <:AbstractArray},
23-
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}},
24-
<:Union{Number, <:AbstractArray}
25-
}
22+
<:Union{Number,<:AbstractArray},
23+
<:Union{<:Dual{T,V,P},<:AbstractArray{<:Dual{T,V,P}}},
24+
<:Union{Number,<:AbstractArray, SciMLBase.NullParameters}
25+
} where {iip, T, V, P}
2626

2727
const DualAbstractLinearProblem = Union{DualLinearProblem, DualALinearProblem, DualBLinearProblem}
2828

29-
3029
function linearsolve_forwarddiff_solve(prob::LinearProblem, alg, args...; kwargs...)
3130
new_A = nodual_value(prob.A)
3231
new_b = nodual_value(prob.b)
3332

34-
newprob = remake(prob; A = new_A, b = new_b)
33+
newprob = remake(prob; A=new_A, b=new_b)
3534

3635
sol = solve(newprob, alg, args...; kwargs...)
3736
uu = sol.u
3837

38+
39+
# Solves Dual partials separately
3940
∂_A = partial_vals(A)
4041
∂_b = partial_vals(b)
4142

42-
rhs = xp_linsolve_rhs(uu, ∂_A, ∂_b)
43+
rhs_list = xp_linsolve_rhs(uu, ∂_A, ∂_b)
4344

44-
partial_prob = remake(newprob, b = rhs)
45-
partial_sol = solve(partial_prob, alg, args...; kwargs...)
45+
partial_sols = map(rhs_list) do rhs
46+
partial_prob = remake(newprob, b=rhs)
47+
solve(partial_prob, alg, args...; kwargs...).u
48+
end
4649

47-
sol, partial_sol
50+
sol, partial_sols
4851
end
4952

53+
function __solve(prob::DualAbstractLinearProblem, alg, args...; kwargs...)
54+
sol, partials = linearsolve_forwarddiff_solve(
55+
prob, alg, args...; kwargs...
56+
)
57+
58+
if get_dual_type(prob.A) !== nothing
59+
dual_type = get_dual_type(prob.A)
60+
elseif get_dual_type(prob.b) !== nothing
61+
dual_type = get_dual_type(prob.b)
62+
return sol
63+
end
64+
65+
linearsolve_dual_solution(sol.u, partials, dual_type)
66+
67+
end
68+
69+
70+
function linearsolve_dual_solution(
71+
u::Number, partials, dual_type)
72+
return dual_type(u, partials)
73+
end
74+
75+
function linearsolve_dual_solution(
76+
u::AbstractArray, partials, dual_type)
77+
partials_list = RecursiveArrayTools.VectorOfArray(partials)
78+
return map(((uᵢ, pᵢ),) -> dual_type(uᵢ, Partials(Tuple(pᵢ))), zip(u, partials_list[i, :] for i in 1:length(partials_list[1])))
79+
end
80+
81+
82+
get_dual_type(x::Dual) = typeof(x)
83+
get_dual_type(x::AbstractArray{<:Dual}) = eltype(x)
84+
get_dual_type(x) = nothing
5085

5186

5287
partial_vals(x::Dual) = ForwardDiff.partials(x)
53-
partial_vals(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x)
88+
partial_vals(x::AbstractArray{<:Dual}) = map(ForwardDiff.partials, x)
5489
partial_vals(x) = nothing
5590

5691
nodual_value(x) = x
@@ -64,21 +99,21 @@ function xp_linsolve_rhs(uu, ∂_A::Union{<:Partials, <:AbstractArray{<:Partials
6499

65100
Auu = [A*uu for A in A_list]
66101

67-
reduce(hcat, b_list .- Auu)
102+
b_list .- Auu
68103
end
69104

70105
function xp_linsolve_rhs(uu, ∂_A::Union{<:Partials, <:AbstractArray{<:Partials}}, ∂_b::Nothing)
71106
A_list = partials_to_list(∂_A)
72107

73108
Auu = [A*uu for A in A_list]
74109

75-
reduce(hcat, Auu)
110+
Auu
76111
end
77112

78113
function xp_linsolve_rhs(uu, ∂_A::Nothing, ∂_b::Union{<:Partials, <:AbstractArray{<:Partials}})
79114
b_list = partials_to_list(∂_b)
80115

81-
reduce(hcat, b_list)
116+
b_list
82117
end
83118

84119

0 commit comments

Comments
 (0)