Skip to content

Commit 6f33486

Browse files
committed
forward steproperty and getproperty more
1 parent 3565d9b commit 6f33486

File tree

1 file changed

+20
-21
lines changed

1 file changed

+20
-21
lines changed

ext/LinearSolveForwardDiffExt.jl

Lines changed: 20 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,7 @@ const DualAbstractLinearProblem = Union{
3535

3636
LinearSolve.@concrete mutable struct DualLinearCache
3737
linear_cache
38-
prob
39-
alg
38+
dual_type
4039
dual_u0
4140
partials_A
4241
partials_b
@@ -147,41 +146,42 @@ function SciMLBase.init(
147146
∂_b = partial_vals(b)
148147
dual_u0 = partial_vals(u0)
149148

150-
newprob = LinearProblem(new_A, new_b, u0 = new_u0)
149+
primal_prob = LinearProblem(new_A, new_b, u0 = new_u0)
151150
#remake(prob; A = new_A, b = new_b, u0 = new_u0)
152151

152+
if get_dual_type(prob.A) !== nothing
153+
dual_type = get_dual_type(prob.A)
154+
elseif get_dual_type(prob.b) !== nothing
155+
dual_type = get_dual_type(prob.b)
156+
end
157+
153158
non_partial_cache = init(
154-
newprob, alg, args...; alias = alias, abstol = abstol, reltol = reltol,
159+
primal_prob, alg, args...; alias = alias, abstol = abstol, reltol = reltol,
155160
maxiters = maxiters, verbose = verbose, Pl = Pl, Pr = Pr, assumptions = assumptions,
156161
sensealg = sensealg, u0 = new_u0, kwargs...)
157-
return DualLinearCache(non_partial_cache, prob, alg, dual_u0, ∂_A, ∂_b)
162+
return DualLinearCache(non_partial_cache, dual_type, dual_u0, ∂_A, ∂_b)
158163
end
159164

160165
function SciMLBase.solve!(cache::DualLinearCache, args...; kwargs...)
161166
sol,
162167
partials = linearsolve_forwarddiff_solve(
163168
cache::DualLinearCache, cache.alg, args...; kwargs...)
164169

165-
if get_dual_type(cache.prob.A) !== nothing
166-
dual_type = get_dual_type(cache.prob.A)
167-
elseif get_dual_type(cache.prob.b) !== nothing
168-
dual_type = get_dual_type(cache.prob.b)
169-
end
170-
171-
dual_sol = linearsolve_dual_solution(sol.u, partials, dual_type)
170+
dual_sol = linearsolve_dual_solution(sol.u, partials, cache.dual_type)
172171

173172
return SciMLBase.build_linear_solution(
174173
cache.alg, dual_sol, sol.resid, sol.cache; sol.retcode, sol.iters, sol.stats
175174
)
176175
end
177176

178-
# If setting A or b for DualLinearCache, also set it for the underlying LinearCache
177+
# If setting A or b for DualLinearCache, put the Dual-stripped versions in the LinearCache
178+
# Also "forwards" setproperty so that
179179
function Base.setproperty!(dc::DualLinearCache, sym::Symbol, val)
180180
# If the property is A or b, also update it in the LinearCache
181181
if sym === :A || sym === :b || sym === :u
182-
if hasproperty(dc, :linear_cache)
183-
setproperty!(dc.linear_cache, sym, nodual_value(val))
184-
end
182+
setproperty!(dc.linear_cache, sym, nodual_value(val))
183+
elseif hasfield(LinearSolve.LinearCache, sym)
184+
setproperty!(dc.linear_cache, sym, val)
185185
end
186186

187187
# Update the partials if setting A or b
@@ -194,13 +194,12 @@ function Base.setproperty!(dc::DualLinearCache, sym::Symbol, val)
194194
end
195195
end
196196

197+
# "Forwards" getproperty to LinearCache if necessary
197198
function Base.getproperty(dc::DualLinearCache, sym::Symbol)
198-
if sym === :A
199-
return dc.linear_cache.A
200-
elseif sym === :b
201-
return dc.linear_cache.b
199+
if hasfield(LinearSolve.LinearCache, sym)
200+
return getproperty(dc.linear_cache, sym)
202201
else
203-
getfield(dc,sym)
202+
return getfield(dc, sym)
204203
end
205204
end
206205

0 commit comments

Comments
 (0)