Skip to content

Commit c57c9d8

Browse files
committed
add LinearSolveForwardDiffExt.jl
1 parent a65fb46 commit c57c9d8

File tree

1 file changed

+123
-0
lines changed

1 file changed

+123
-0
lines changed

ext/LinearSolveForwardDiffExt.jl

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
module LinearSolveForwardDiffExt
2+
3+
const DualLinearProblem = LinearProblem{
4+
<:Union{Number, <:AbstractArray}, iip,
5+
<:Union{<:Dual{T,V,P}, <:AbstractArray{<:Dual{T,V,P}}},
6+
<:Union{<:Dual{T,V,P}, <:AbstractArray{<:Dual{T,V,P}}},
7+
<:Union{Number, <:AbstractArray}
8+
} where {iip, T, V}
9+
10+
11+
const DualALinearProblem = LinearProblem{
12+
<:Union{Number, <:AbstractArray},
13+
iip,
14+
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}},
15+
<:Union{Number, <:AbstractArray},
16+
<:Union{Number, <:AbstractArray}
17+
}
18+
19+
const DualBLinearProblem = LinearProblem{
20+
<:Union{Number, <:AbstractArray},
21+
iip,
22+
<:Union{Number, <:AbstractArray},
23+
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}},
24+
<:Union{Number, <:AbstractArray}
25+
}
26+
27+
const DualAbstractLinearProblem = Union{DualLinearProblem, DualALinearProblem, DualBLinearProblem}
28+
29+
30+
function linearsolve_forwarddiff_solve(prob::LinearProblem, alg, args...; kwargs...)
31+
new_A = nodual_value(prob.A)
32+
new_b = nodual_value(prob.b)
33+
34+
newprob = remake(prob; A = new_A, b = new_b)
35+
36+
sol = solve(newprob, alg, args...; kwargs...)
37+
uu = sol.u
38+
39+
∂_A = partial_vals(A)
40+
∂_b = partial_vals(b)
41+
42+
43+
44+
if uu isa Number
45+
46+
else
47+
48+
end
49+
50+
end
51+
52+
53+
54+
partial_vals(x::Dual) = ForwardDiff.partials(x)
55+
partial_vals(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x)
56+
partial_vals(x) = nothing
57+
58+
nodual_value(x) = x
59+
nodual_value(x::Dual) = ForwardDiff.value(x)
60+
nodual_value(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x)
61+
62+
63+
function x_p_linsolve(new_A, uu, ∂_A::Union{<:Partials, <:AbstractArray{<:Partials}}, ∂_b::Union{<:Partials, <:AbstractArray{<:Partials}})
64+
A_list = partials_to_list(∂_A)
65+
b_list = partials_to_list(∂_b)
66+
67+
Auu = [A*uu for A in A_list]
68+
69+
linsol_rhs = reduce(hcat, b_list .- Auu)
70+
71+
new_A \ linsol_rhs
72+
end
73+
74+
function x_p_linsolve(new_A, uu, ∂_A::Union{<:Partials, <:AbstractArray{<:Partials}}, ∂_b::Nothing)
75+
A_list = partials_to_list(∂_A)
76+
77+
Auu = [A*uu for A in A_list]
78+
79+
linsol_rhs = reduce(hcat, Auu)
80+
81+
new_A \ linsol_rhs
82+
end
83+
84+
function x_p_linsolve(new_A, uu, ∂_A::Nothing, ∂_b::Union{<:Partials, <:AbstractArray{<:Partials}})
85+
b_list = partials_to_list(∂_b)
86+
87+
linsol_rhs = reduce(hcat, b_list)
88+
89+
new_A \ linsol_rhs
90+
end
91+
92+
93+
94+
function partials_to_list(partial_matrix::Vector)
95+
p = eachindex(first(partial_matrix))
96+
[[partial[i] for partial in partial_matrix] for i in p]
97+
end
98+
99+
function partials_to_list(partial_matrix)
100+
p = length(first(partial_matrix))
101+
m,n = size(partial_matrix)
102+
res_list = fill(zeros(m,n),p)
103+
for k in 1:p
104+
res = zeros(m,n)
105+
for i in 1:m
106+
for j in 1:n
107+
res[i,j] = partial_matrix[i,j][k]
108+
end
109+
end
110+
res_list[k] = res
111+
end
112+
return res_list
113+
end
114+
115+
116+
117+
118+
119+
120+
121+
122+
123+

0 commit comments

Comments
 (0)