@@ -61,36 +61,6 @@ LinearSolve.@concrete mutable struct DualLinearCache
61
61
dual_u
62
62
end
63
63
64
- # function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwargs...)
65
- # # Solve the primal problem
66
- # dual_u0 = copy(cache.linear_cache.u)
67
- # sol = solve!(cache.linear_cache, alg, args...; kwargs...)
68
- # primal_b = copy(cache.linear_cache.b)
69
- # uu = sol.u
70
-
71
- # primal_sol = deepcopy(sol)
72
-
73
- # # Solves Dual partials separately
74
- # ∂_A = cache.partials_A
75
- # ∂_b = cache.partials_b
76
-
77
- # rhs_list = xp_linsolve_rhs(uu, ∂_A, ∂_b)
78
-
79
- # cache.linear_cache.u = dual_u0
80
- # # We can reuse the linear cache, because the same factorization will work for the partials.
81
- # for i in eachindex(rhs_list)
82
- # cache.linear_cache.b = rhs_list[i]
83
- # rhs_list[i] = copy(solve!(cache.linear_cache, alg, args...; kwargs...).u)
84
- # end
85
-
86
- # # Reset to the original `b` and `u`, users will expect that `b` doesn't change if they don't tell it to
87
- # cache.linear_cache.b = primal_b
88
-
89
- # partial_sols = rhs_list
90
-
91
- # primal_sol, partial_sols
92
- # end
93
-
94
64
function linearsolve_forwarddiff_solve (cache:: DualLinearCache , alg, args... ; kwargs... )
95
65
# Solve the primal problem
96
66
dual_u0 = copy (cache. linear_cache. u)
@@ -108,17 +78,16 @@ function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwa
108
78
109
79
cache. linear_cache. u = dual_u0
110
80
# We can reuse the linear cache, because the same factorization will work for the partials.
111
- partial_sols = []
112
81
for i in eachindex (rhs_list)
113
82
cache. linear_cache. b = rhs_list[i]
114
- # For nested duals, the result of this solve might also be a dual number
115
- # which will be handled recursively by the same mechanism
116
- push! (partial_sols, copy (solve! (cache. linear_cache, alg, args... ; kwargs... ). u))
83
+ rhs_list[i] = copy (solve! (cache. linear_cache, alg, args... ; kwargs... ). u)
117
84
end
118
85
119
- # Reset to the original `b` and `u`
86
+ # Reset to the original `b` and `u`, users will expect that `b` doesn't change if they don't tell it to
120
87
cache. linear_cache. b = primal_b
121
88
89
+ partial_sols = rhs_list
90
+
122
91
primal_sol, partial_sols
123
92
end
124
93
@@ -147,30 +116,6 @@ function xp_linsolve_rhs(
147
116
b_list
148
117
end
149
118
150
- #=
151
- function SciMLBase.solve(prob::DualAbstractLinearProblem, args...; kwargs...)
152
- return solve(prob, nothing, args...; kwargs...)
153
- end
154
-
155
- function SciMLBase.solve(prob::DualAbstractLinearProblem, ::Nothing, args...;
156
- assump = OperatorAssumptions(issquare(nodual_value(prob.A))), kwargs...)
157
- # Extract primal values
158
- primal_A = nodual_value(prob.A)
159
- primal_b = nodual_value(prob.b)
160
-
161
- # Use the default algorithm selection based on primal values
162
- default_alg = LinearSolve.defaultalg(primal_A, primal_b, assump)
163
-
164
- # Solve with the selected algorithm
165
- return solve(prob, default_alg, args...; kwargs...)
166
- end
167
-
168
- function SciMLBase.solve(prob::DualAbstractLinearProblem,
169
- alg::LinearSolve.SciMLLinearSolveAlgorithm, args...; kwargs...)
170
- solve!(init(prob, alg, args...; kwargs...))
171
- end
172
- =#
173
-
174
119
function linearsolve_dual_solution (
175
120
u:: Number , partials, dual_type)
176
121
return dual_type (u, partials)
@@ -252,7 +197,7 @@ function SciMLBase.init(
252
197
primal_prob, alg, args... ; alias = alias, abstol = abstol, reltol = reltol,
253
198
maxiters = maxiters, verbose = verbose, Pl = Pl, Pr = Pr, assumptions = assumptions,
254
199
sensealg = sensealg, u0 = new_u0, kwargs... )
255
- return DualLinearCache (non_partial_cache, dual_type, ∂_A, ∂_b, ! isnothing (∂_b) ? zero .(∂_b) : ∂_b, A, b, zero .(b ))
200
+ return DualLinearCache (non_partial_cache, dual_type, ∂_A, ∂_b, ! isnothing (∂_b) ? zero .(∂_b) : ∂_b, A, b, zeros (dual_type, length (b) ))
256
201
end
257
202
258
203
function SciMLBase. solve! (cache:: DualLinearCache , args... ; kwargs... )
@@ -264,14 +209,13 @@ function SciMLBase.solve!(cache::DualLinearCache, alg::SciMLLinearSolveAlgorithm
264
209
partials = linearsolve_forwarddiff_solve (
265
210
cache:: DualLinearCache , cache. alg, args... ; kwargs... )
266
211
dual_sol = linearsolve_dual_solution (sol. u, partials, cache. dual_type)
267
-
212
+ Main . @infiltrate
268
213
cache. dual_u = dual_sol
269
214
270
215
return SciMLBase. build_linear_solution (
271
216
cache. alg, dual_sol, sol. resid, cache; sol. retcode, sol. iters, sol. stats
272
217
)
273
218
end
274
- = #
275
219
276
220
# If setting A or b for DualLinearCache, put the Dual-stripped versions in the LinearCache
277
221
function Base. setproperty! (dc:: DualLinearCache , sym:: Symbol , val)
0 commit comments