Skip to content

Commit f829926

Browse files
committed
reuse list
1 parent 0717289 commit f829926

File tree

1 file changed

+6
-62
lines changed

1 file changed

+6
-62
lines changed

ext/LinearSolveForwardDiffExt.jl

Lines changed: 6 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -61,36 +61,6 @@ LinearSolve.@concrete mutable struct DualLinearCache
6161
dual_u
6262
end
6363

64-
# function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwargs...)
65-
# # Solve the primal problem
66-
# dual_u0 = copy(cache.linear_cache.u)
67-
# sol = solve!(cache.linear_cache, alg, args...; kwargs...)
68-
# primal_b = copy(cache.linear_cache.b)
69-
# uu = sol.u
70-
71-
# primal_sol = deepcopy(sol)
72-
73-
# # Solves Dual partials separately
74-
# ∂_A = cache.partials_A
75-
# ∂_b = cache.partials_b
76-
77-
# rhs_list = xp_linsolve_rhs(uu, ∂_A, ∂_b)
78-
79-
# cache.linear_cache.u = dual_u0
80-
# # We can reuse the linear cache, because the same factorization will work for the partials.
81-
# for i in eachindex(rhs_list)
82-
# cache.linear_cache.b = rhs_list[i]
83-
# rhs_list[i] = copy(solve!(cache.linear_cache, alg, args...; kwargs...).u)
84-
# end
85-
86-
# # Reset to the original `b` and `u`, users will expect that `b` doesn't change if they don't tell it to
87-
# cache.linear_cache.b = primal_b
88-
89-
# partial_sols = rhs_list
90-
91-
# primal_sol, partial_sols
92-
# end
93-
9464
function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwargs...)
9565
# Solve the primal problem
9666
dual_u0 = copy(cache.linear_cache.u)
@@ -108,17 +78,16 @@ function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwa
10878

10979
cache.linear_cache.u = dual_u0
11080
# We can reuse the linear cache, because the same factorization will work for the partials.
111-
partial_sols = []
11281
for i in eachindex(rhs_list)
11382
cache.linear_cache.b = rhs_list[i]
114-
# For nested duals, the result of this solve might also be a dual number
115-
# which will be handled recursively by the same mechanism
116-
push!(partial_sols, copy(solve!(cache.linear_cache, alg, args...; kwargs...).u))
83+
rhs_list[i] = copy(solve!(cache.linear_cache, alg, args...; kwargs...).u)
11784
end
11885

119-
# Reset to the original `b` and `u`
86+
# Reset to the original `b` and `u`, users will expect that `b` doesn't change if they don't tell it to
12087
cache.linear_cache.b = primal_b
12188

89+
partial_sols = rhs_list
90+
12291
primal_sol, partial_sols
12392
end
12493

@@ -147,30 +116,6 @@ function xp_linsolve_rhs(
147116
b_list
148117
end
149118

150-
#=
151-
function SciMLBase.solve(prob::DualAbstractLinearProblem, args...; kwargs...)
152-
return solve(prob, nothing, args...; kwargs...)
153-
end
154-
155-
function SciMLBase.solve(prob::DualAbstractLinearProblem, ::Nothing, args...;
156-
assump = OperatorAssumptions(issquare(nodual_value(prob.A))), kwargs...)
157-
# Extract primal values
158-
primal_A = nodual_value(prob.A)
159-
primal_b = nodual_value(prob.b)
160-
161-
# Use the default algorithm selection based on primal values
162-
default_alg = LinearSolve.defaultalg(primal_A, primal_b, assump)
163-
164-
# Solve with the selected algorithm
165-
return solve(prob, default_alg, args...; kwargs...)
166-
end
167-
168-
function SciMLBase.solve(prob::DualAbstractLinearProblem,
169-
alg::LinearSolve.SciMLLinearSolveAlgorithm, args...; kwargs...)
170-
solve!(init(prob, alg, args...; kwargs...))
171-
end
172-
=#
173-
174119
function linearsolve_dual_solution(
175120
u::Number, partials, dual_type)
176121
return dual_type(u, partials)
@@ -252,7 +197,7 @@ function SciMLBase.init(
252197
primal_prob, alg, args...; alias = alias, abstol = abstol, reltol = reltol,
253198
maxiters = maxiters, verbose = verbose, Pl = Pl, Pr = Pr, assumptions = assumptions,
254199
sensealg = sensealg, u0 = new_u0, kwargs...)
255-
return DualLinearCache(non_partial_cache, dual_type, ∂_A, ∂_b, !isnothing(∂_b) ? zero.(∂_b) : ∂_b, A, b, zero.(b))
200+
return DualLinearCache(non_partial_cache, dual_type, ∂_A, ∂_b, !isnothing(∂_b) ? zero.(∂_b) : ∂_b, A, b, zeros(dual_type, length(b)))
256201
end
257202

258203
function SciMLBase.solve!(cache::DualLinearCache, args...; kwargs...)
@@ -264,14 +209,13 @@ function SciMLBase.solve!(cache::DualLinearCache, alg::SciMLLinearSolveAlgorithm
264209
partials = linearsolve_forwarddiff_solve(
265210
cache::DualLinearCache, cache.alg, args...; kwargs...)
266211
dual_sol = linearsolve_dual_solution(sol.u, partials, cache.dual_type)
267-
212+
Main.@infiltrate
268213
cache.dual_u = dual_sol
269214

270215
return SciMLBase.build_linear_solution(
271216
cache.alg, dual_sol, sol.resid, cache; sol.retcode, sol.iters, sol.stats
272217
)
273218
end
274-
=#
275219

276220
# If setting A or b for DualLinearCache, put the Dual-stripped versions in the LinearCache
277221
function Base.setproperty!(dc::DualLinearCache, sym::Symbol, val)

0 commit comments

Comments
 (0)