@@ -34,17 +34,15 @@ const DualAbstractLinearProblem = Union{
34
34
DualLinearProblem, DualALinearProblem, DualBLinearProblem}
35
35
36
36
LinearSolve. @concrete mutable struct DualLinearCache
37
- cache
37
+ linear_cache
38
38
prob
39
39
alg
40
- A
41
- b
42
40
partials_A
43
41
partials_b
44
42
end
45
43
46
44
function linearsolve_forwarddiff_solve (cache:: DualLinearCache , alg, args... ; kwargs... )
47
- sol = solve! (cache. cache , alg, args... ; kwargs... )
45
+ sol = solve! (cache. linear_cache , alg, args... ; kwargs... )
48
46
uu = sol. u
49
47
50
48
# Solves Dual partials separately
@@ -53,7 +51,7 @@ function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwa
53
51
54
52
rhs_list = xp_linsolve_rhs (uu, ∂_A, ∂_b)
55
53
56
- new_A = nodual_value (cache. prob . A)
54
+ new_A = nodual_value (cache. A)
57
55
partial_prob = LinearProblem (new_A, rhs_list[1 ])
58
56
partial_cache = init (partial_prob, alg, args... ; kwargs... )
59
57
@@ -67,44 +65,6 @@ function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwa
67
65
sol, partial_sols
68
66
end
69
67
70
- function SciMLBase. solve (prob:: DualAbstractLinearProblem , args... ; kwargs... )
71
- return solve (prob, nothing , args... ; kwargs... )
72
- end
73
-
74
- function SciMLBase. solve (prob:: DualAbstractLinearProblem , :: Nothing , args... ;
75
- assump = OperatorAssumptions (issquare (prob. A)), kwargs... )
76
- return solve (prob, LinearSolve. defaultalg (prob. A, prob. b, assump), args... ; kwargs... )
77
- end
78
-
79
- function SciMLBase. solve (prob:: DualAbstractLinearProblem ,
80
- alg:: LinearSolve.SciMLLinearSolveAlgorithm , args... ; kwargs... )
81
- solve! (init (prob, alg, args... ; kwargs... ))
82
- end
83
-
84
- function linearsolve_dual_solution (
85
- u:: Number , partials, dual_type)
86
- return dual_type (u, partials)
87
- end
88
-
89
- function linearsolve_dual_solution (
90
- u:: AbstractArray , partials, dual_type)
91
- partials_list = RecursiveArrayTools. VectorOfArray (partials)
92
- return map (((uᵢ, pᵢ),) -> dual_type (uᵢ, Partials (Tuple (pᵢ))),
93
- zip (u, partials_list[i, :] for i in 1 : length (partials_list[1 ])))
94
- end
95
-
96
- get_dual_type (x:: Dual ) = typeof (x)
97
- get_dual_type (x:: AbstractArray{<:Dual} ) = eltype (x)
98
- get_dual_type (x) = nothing
99
-
100
- partial_vals (x:: Dual ) = ForwardDiff. partials (x)
101
- partial_vals (x:: AbstractArray{<:Dual} ) = map (ForwardDiff. partials, x)
102
- partial_vals (x) = nothing
103
-
104
- nodual_value (x) = x
105
- nodual_value (x:: Dual ) = ForwardDiff. value (x)
106
- nodual_value (x:: AbstractArray{<:Dual} ) = map (ForwardDiff. value, x)
107
-
108
68
function xp_linsolve_rhs (uu, ∂_A:: Union{<:Partials, <:AbstractArray{<:Partials}} ,
109
69
∂_b:: Union{<:Partials, <:AbstractArray{<:Partials}} )
110
70
A_list = partials_to_list (∂_A)
@@ -130,25 +90,30 @@ function xp_linsolve_rhs(
130
90
b_list
131
91
end
132
92
133
- function partials_to_list (partial_matrix:: Vector )
134
- p = eachindex (first (partial_matrix))
135
- [[partial[i] for partial in partial_matrix] for i in p]
93
+ function SciMLBase. solve (prob:: DualAbstractLinearProblem , args... ; kwargs... )
94
+ return solve (prob, nothing , args... ; kwargs... )
136
95
end
137
96
138
- function partials_to_list (partial_matrix)
139
- p = length (first (partial_matrix))
140
- m, n = size (partial_matrix)
141
- res_list = fill (zeros (m, n), p)
142
- for k in 1 : p
143
- res = zeros (m, n)
144
- for i in 1 : m
145
- for j in 1 : n
146
- res[i, j] = partial_matrix[i, j][k]
147
- end
148
- end
149
- res_list[k] = res
150
- end
151
- return res_list
97
+ function SciMLBase. solve (prob:: DualAbstractLinearProblem , :: Nothing , args... ;
98
+ assump = OperatorAssumptions (issquare (prob. A)), kwargs... )
99
+ return solve (prob, LinearSolve. defaultalg (prob. A, prob. b, assump), args... ; kwargs... )
100
+ end
101
+
102
+ function SciMLBase. solve (prob:: DualAbstractLinearProblem ,
103
+ alg:: LinearSolve.SciMLLinearSolveAlgorithm , args... ; kwargs... )
104
+ solve! (init (prob, alg, args... ; kwargs... ))
105
+ end
106
+
107
+ function linearsolve_dual_solution (
108
+ u:: Number , partials, dual_type)
109
+ return dual_type (u, partials)
110
+ end
111
+
112
+ function linearsolve_dual_solution (
113
+ u:: AbstractArray , partials, dual_type)
114
+ partials_list = RecursiveArrayTools. VectorOfArray (partials)
115
+ return map (((uᵢ, pᵢ),) -> dual_type (uᵢ, Partials (Tuple (pᵢ))),
116
+ zip (u, partials_list[i, :] for i in 1 : length (partials_list[1 ])))
152
117
end
153
118
154
119
function SciMLBase. init (
@@ -164,6 +129,7 @@ function SciMLBase.init(
164
129
assumptions = OperatorAssumptions (issquare (prob. A)),
165
130
sensealg = LinearSolveAdjoint (),
166
131
kwargs... )
132
+
167
133
new_A = nodual_value (prob. A)
168
134
new_b = nodual_value (prob. b)
169
135
@@ -177,7 +143,7 @@ function SciMLBase.init(
177
143
maxiters = maxiters, verbose = verbose, Pl = Pl, Pr = Pr, assumptions = assumptions,
178
144
sensealg = sensealg, kwargs... )
179
145
180
- return DualLinearCache (non_partial_cache, prob, alg, new_A, new_b, ∂_A, ∂_b)
146
+ return DualLinearCache (non_partial_cache, prob, alg, ∂_A, ∂_b)
181
147
end
182
148
183
149
function SciMLBase. solve! (cache:: DualLinearCache , args... ; kwargs... )
@@ -198,4 +164,100 @@ function SciMLBase.solve!(cache::DualLinearCache, args...; kwargs...)
198
164
)
199
165
end
200
166
167
+ # If setting A or b for DualLinearCache, also set it for the underlying LinearCache
168
+ function Base. setproperty! (dc:: DualLinearCache , sym:: Symbol , val)
169
+ # If the property is A or b, also update it in the LinearCache
170
+ if sym === :A || sym === :b
171
+ if hasproperty (dc, :linear_cache )
172
+ setproperty! (dc. linear_cache, sym, nodual_value (val))
173
+ end
174
+ end
175
+
176
+ # Update the partials if setting A or b
177
+ if sym === :A
178
+ setfield! (dc, :partials_A , partial_vals (val))
179
+ elseif sym === :b
180
+ setfield! (dc, :partials_b , partial_vals (val))
181
+ end
182
+
183
+ return val
184
+ end
185
+
186
+ function Base. getproperty (dc:: DualLinearCache , sym:: Symbol )
187
+ if sym === :A
188
+ return dc. linear_cache. A
189
+ elseif sym === :b
190
+ return dc. linear_cache. b
191
+ else
192
+ getfield (dc,sym)
193
+ end
194
+ end
195
+
196
+ function SciMLBase. reinit! (cache:: DualLinearCache ;
197
+ A = nothing ,
198
+ b = cache. b,
199
+ u = cache. u,
200
+ p = nothing ,
201
+ reuse_precs = false )
202
+ (; alg, cacheval, abstol, reltol, maxiters, verbose, assumptions, sensealg) = cache
203
+
204
+ isfresh = ! isnothing (A)
205
+ precsisfresh = ! reuse_precs && (isfresh || ! isnothing (p))
206
+ isfresh |= cache. isfresh
207
+ precsisfresh |= cache. precsisfresh
208
+
209
+ A = isnothing (A) ? cache. A : A
210
+ b = isnothing (b) ? cache. b : b
211
+ u = isnothing (u) ? cache. u : u
212
+ p = isnothing (p) ? cache. p : p
213
+ Pl = cache. Pl
214
+ Pr = cache. Pr
215
+
216
+ cache. A = A
217
+ cache. b = b
218
+ cache. u = u
219
+ cache. p = p
220
+ cache. Pl = Pl
221
+ cache. Pr = Pr
222
+ cache. isfresh = true
223
+ cache. precsisfresh = precsisfresh
224
+ nothing
225
+ end
226
+
227
+ # Helper functions for Dual numbers
228
+ get_dual_type (x:: Dual ) = typeof (x)
229
+ get_dual_type (x:: AbstractArray{<:Dual} ) = eltype (x)
230
+ get_dual_type (x) = nothing
231
+
232
+ partial_vals (x:: Dual ) = ForwardDiff. partials (x)
233
+ partial_vals (x:: AbstractArray{<:Dual} ) = map (ForwardDiff. partials, x)
234
+ partial_vals (x) = nothing
235
+
236
+ nodual_value (x) = x
237
+ nodual_value (x:: Dual ) = ForwardDiff. value (x)
238
+ nodual_value (x:: AbstractArray{<:Dual} ) = map (ForwardDiff. value, x)
239
+
240
+
241
+ function partials_to_list (partial_matrix:: Vector )
242
+ p = eachindex (first (partial_matrix))
243
+ [[partial[i] for partial in partial_matrix] for i in p]
244
+ end
245
+
246
+ function partials_to_list (partial_matrix)
247
+ p = length (first (partial_matrix))
248
+ m, n = size (partial_matrix)
249
+ res_list = fill (zeros (m, n), p)
250
+ for k in 1 : p
251
+ res = zeros (m, n)
252
+ for i in 1 : m
253
+ for j in 1 : n
254
+ res[i, j] = partial_matrix[i, j][k]
255
+ end
256
+ end
257
+ res_list[k] = res
258
+ end
259
+ return res_list
260
+ end
261
+
262
+
201
263
end
0 commit comments