Skip to content

Commit 3547ec7

Browse files
committed
rearrange, make sure that dualcache works
1 parent 922f7ec commit 3547ec7

File tree

1 file changed

+123
-61
lines changed

1 file changed

+123
-61
lines changed

ext/LinearSolveForwardDiffExt.jl

Lines changed: 123 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -34,17 +34,15 @@ const DualAbstractLinearProblem = Union{
3434
DualLinearProblem, DualALinearProblem, DualBLinearProblem}
3535

3636
LinearSolve.@concrete mutable struct DualLinearCache
37-
cache
37+
linear_cache
3838
prob
3939
alg
40-
A
41-
b
4240
partials_A
4341
partials_b
4442
end
4543

4644
function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwargs...)
47-
sol = solve!(cache.cache, alg, args...; kwargs...)
45+
sol = solve!(cache.linear_cache, alg, args...; kwargs...)
4846
uu = sol.u
4947

5048
# Solves Dual partials separately
@@ -53,7 +51,7 @@ function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwa
5351

5452
rhs_list = xp_linsolve_rhs(uu, ∂_A, ∂_b)
5553

56-
new_A = nodual_value(cache.prob.A)
54+
new_A = nodual_value(cache.A)
5755
partial_prob = LinearProblem(new_A, rhs_list[1])
5856
partial_cache = init(partial_prob, alg, args...; kwargs...)
5957

@@ -67,44 +65,6 @@ function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwa
6765
sol, partial_sols
6866
end
6967

70-
function SciMLBase.solve(prob::DualAbstractLinearProblem, args...; kwargs...)
71-
return solve(prob, nothing, args...; kwargs...)
72-
end
73-
74-
function SciMLBase.solve(prob::DualAbstractLinearProblem, ::Nothing, args...;
75-
assump = OperatorAssumptions(issquare(prob.A)), kwargs...)
76-
return solve(prob, LinearSolve.defaultalg(prob.A, prob.b, assump), args...; kwargs...)
77-
end
78-
79-
function SciMLBase.solve(prob::DualAbstractLinearProblem,
80-
alg::LinearSolve.SciMLLinearSolveAlgorithm, args...; kwargs...)
81-
solve!(init(prob, alg, args...; kwargs...))
82-
end
83-
84-
function linearsolve_dual_solution(
85-
u::Number, partials, dual_type)
86-
return dual_type(u, partials)
87-
end
88-
89-
function linearsolve_dual_solution(
90-
u::AbstractArray, partials, dual_type)
91-
partials_list = RecursiveArrayTools.VectorOfArray(partials)
92-
return map(((uᵢ, pᵢ),) -> dual_type(uᵢ, Partials(Tuple(pᵢ))),
93-
zip(u, partials_list[i, :] for i in 1:length(partials_list[1])))
94-
end
95-
96-
get_dual_type(x::Dual) = typeof(x)
97-
get_dual_type(x::AbstractArray{<:Dual}) = eltype(x)
98-
get_dual_type(x) = nothing
99-
100-
partial_vals(x::Dual) = ForwardDiff.partials(x)
101-
partial_vals(x::AbstractArray{<:Dual}) = map(ForwardDiff.partials, x)
102-
partial_vals(x) = nothing
103-
104-
nodual_value(x) = x
105-
nodual_value(x::Dual) = ForwardDiff.value(x)
106-
nodual_value(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x)
107-
10868
function xp_linsolve_rhs(uu, ∂_A::Union{<:Partials, <:AbstractArray{<:Partials}},
10969
∂_b::Union{<:Partials, <:AbstractArray{<:Partials}})
11070
A_list = partials_to_list(∂_A)
@@ -130,25 +90,30 @@ function xp_linsolve_rhs(
13090
b_list
13191
end
13292

133-
function partials_to_list(partial_matrix::Vector)
134-
p = eachindex(first(partial_matrix))
135-
[[partial[i] for partial in partial_matrix] for i in p]
93+
function SciMLBase.solve(prob::DualAbstractLinearProblem, args...; kwargs...)
94+
return solve(prob, nothing, args...; kwargs...)
13695
end
13796

138-
function partials_to_list(partial_matrix)
139-
p = length(first(partial_matrix))
140-
m, n = size(partial_matrix)
141-
res_list = fill(zeros(m, n), p)
142-
for k in 1:p
143-
res = zeros(m, n)
144-
for i in 1:m
145-
for j in 1:n
146-
res[i, j] = partial_matrix[i, j][k]
147-
end
148-
end
149-
res_list[k] = res
150-
end
151-
return res_list
97+
function SciMLBase.solve(prob::DualAbstractLinearProblem, ::Nothing, args...;
98+
assump = OperatorAssumptions(issquare(prob.A)), kwargs...)
99+
return solve(prob, LinearSolve.defaultalg(prob.A, prob.b, assump), args...; kwargs...)
100+
end
101+
102+
function SciMLBase.solve(prob::DualAbstractLinearProblem,
103+
alg::LinearSolve.SciMLLinearSolveAlgorithm, args...; kwargs...)
104+
solve!(init(prob, alg, args...; kwargs...))
105+
end
106+
107+
function linearsolve_dual_solution(
108+
u::Number, partials, dual_type)
109+
return dual_type(u, partials)
110+
end
111+
112+
function linearsolve_dual_solution(
113+
u::AbstractArray, partials, dual_type)
114+
partials_list = RecursiveArrayTools.VectorOfArray(partials)
115+
return map(((uᵢ, pᵢ),) -> dual_type(uᵢ, Partials(Tuple(pᵢ))),
116+
zip(u, partials_list[i, :] for i in 1:length(partials_list[1])))
152117
end
153118

154119
function SciMLBase.init(
@@ -164,6 +129,7 @@ function SciMLBase.init(
164129
assumptions = OperatorAssumptions(issquare(prob.A)),
165130
sensealg = LinearSolveAdjoint(),
166131
kwargs...)
132+
167133
new_A = nodual_value(prob.A)
168134
new_b = nodual_value(prob.b)
169135

@@ -177,7 +143,7 @@ function SciMLBase.init(
177143
maxiters = maxiters, verbose = verbose, Pl = Pl, Pr = Pr, assumptions = assumptions,
178144
sensealg = sensealg, kwargs...)
179145

180-
return DualLinearCache(non_partial_cache, prob, alg, new_A, new_b, ∂_A, ∂_b)
146+
return DualLinearCache(non_partial_cache, prob, alg, ∂_A, ∂_b)
181147
end
182148

183149
function SciMLBase.solve!(cache::DualLinearCache, args...; kwargs...)
@@ -198,4 +164,100 @@ function SciMLBase.solve!(cache::DualLinearCache, args...; kwargs...)
198164
)
199165
end
200166

167+
# If setting A or b for DualLinearCache, also set it for the underlying LinearCache
168+
function Base.setproperty!(dc::DualLinearCache, sym::Symbol, val)
169+
# If the property is A or b, also update it in the LinearCache
170+
if sym === :A || sym === :b
171+
if hasproperty(dc, :linear_cache)
172+
setproperty!(dc.linear_cache, sym, nodual_value(val))
173+
end
174+
end
175+
176+
# Update the partials if setting A or b
177+
if sym === :A
178+
setfield!(dc, :partials_A, partial_vals(val))
179+
elseif sym === :b
180+
setfield!(dc, :partials_b, partial_vals(val))
181+
end
182+
183+
return val
184+
end
185+
186+
function Base.getproperty(dc::DualLinearCache, sym::Symbol)
187+
if sym === :A
188+
return dc.linear_cache.A
189+
elseif sym === :b
190+
return dc.linear_cache.b
191+
else
192+
getfield(dc,sym)
193+
end
194+
end
195+
196+
function SciMLBase.reinit!(cache::DualLinearCache;
197+
A = nothing,
198+
b = cache.b,
199+
u = cache.u,
200+
p = nothing,
201+
reuse_precs = false)
202+
(; alg, cacheval, abstol, reltol, maxiters, verbose, assumptions, sensealg) = cache
203+
204+
isfresh = !isnothing(A)
205+
precsisfresh = !reuse_precs && (isfresh || !isnothing(p))
206+
isfresh |= cache.isfresh
207+
precsisfresh |= cache.precsisfresh
208+
209+
A = isnothing(A) ? cache.A : A
210+
b = isnothing(b) ? cache.b : b
211+
u = isnothing(u) ? cache.u : u
212+
p = isnothing(p) ? cache.p : p
213+
Pl = cache.Pl
214+
Pr = cache.Pr
215+
216+
cache.A = A
217+
cache.b = b
218+
cache.u = u
219+
cache.p = p
220+
cache.Pl = Pl
221+
cache.Pr = Pr
222+
cache.isfresh = true
223+
cache.precsisfresh = precsisfresh
224+
nothing
225+
end
226+
227+
# Helper functions for Dual numbers
228+
get_dual_type(x::Dual) = typeof(x)
229+
get_dual_type(x::AbstractArray{<:Dual}) = eltype(x)
230+
get_dual_type(x) = nothing
231+
232+
partial_vals(x::Dual) = ForwardDiff.partials(x)
233+
partial_vals(x::AbstractArray{<:Dual}) = map(ForwardDiff.partials, x)
234+
partial_vals(x) = nothing
235+
236+
nodual_value(x) = x
237+
nodual_value(x::Dual) = ForwardDiff.value(x)
238+
nodual_value(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x)
239+
240+
241+
function partials_to_list(partial_matrix::Vector)
242+
p = eachindex(first(partial_matrix))
243+
[[partial[i] for partial in partial_matrix] for i in p]
244+
end
245+
246+
function partials_to_list(partial_matrix)
247+
p = length(first(partial_matrix))
248+
m, n = size(partial_matrix)
249+
res_list = fill(zeros(m, n), p)
250+
for k in 1:p
251+
res = zeros(m, n)
252+
for i in 1:m
253+
for j in 1:n
254+
res[i, j] = partial_matrix[i, j][k]
255+
end
256+
end
257+
res_list[k] = res
258+
end
259+
return res_list
260+
end
261+
262+
201263
end

0 commit comments

Comments
 (0)