1
- module LinearSolveForwardDiffExt
1
+ module LinearSolveForwardDiffExt
2
2
3
3
using LinearSolve
4
4
using ForwardDiff
@@ -7,50 +7,49 @@ using SciMLBase
7
7
using RecursiveArrayTools
8
8
9
9
const DualLinearProblem = LinearProblem{
10
- <: Union{Number,<:AbstractArray, Nothing} ,iip,
11
- <: Union{<:Dual{T,V, P},<:AbstractArray{<:Dual{T,V, P}}} ,
12
- <: Union{<:Dual{T,V, P},<:AbstractArray{<:Dual{T,V, P}}} ,
13
- <: Union{Number,<:AbstractArray, SciMLBase.NullParameters}
10
+ <: Union{Number, <:AbstractArray, Nothing} , iip,
11
+ <: Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}} ,
12
+ <: Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}} ,
13
+ <: Union{Number, <:AbstractArray, SciMLBase.NullParameters}
14
14
} where {iip, T, V, P}
15
15
16
-
17
16
const DualALinearProblem = LinearProblem{
18
- <: Union{Number,<:AbstractArray, Nothing} ,
17
+ <: Union{Number, <:AbstractArray, Nothing} ,
19
18
iip,
20
- <: Union{<:Dual{T,V, P},<:AbstractArray{<:Dual{T,V, P}}} ,
21
- <: Union{Number,<:AbstractArray} ,
22
- <: Union{Number,<:AbstractArray, SciMLBase.NullParameters}
19
+ <: Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}} ,
20
+ <: Union{Number, <:AbstractArray} ,
21
+ <: Union{Number, <:AbstractArray, SciMLBase.NullParameters}
23
22
} where {iip, T, V, P}
24
23
25
24
const DualBLinearProblem = LinearProblem{
26
- <: Union{Number,<:AbstractArray, Nothing} ,
25
+ <: Union{Number, <:AbstractArray, Nothing} ,
27
26
iip,
28
- <: Union{Number,<:AbstractArray} ,
29
- <: Union{<:Dual{T,V, P},<:AbstractArray{<:Dual{T,V, P}}} ,
30
- <: Union{Number,<:AbstractArray, SciMLBase.NullParameters}
27
+ <: Union{Number, <:AbstractArray} ,
28
+ <: Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}} ,
29
+ <: Union{Number, <:AbstractArray, SciMLBase.NullParameters}
31
30
} where {iip, T, V, P}
32
31
33
- const DualAbstractLinearProblem = Union{DualLinearProblem, DualALinearProblem, DualBLinearProblem}
32
+ const DualAbstractLinearProblem = Union{
33
+ DualLinearProblem, DualALinearProblem, DualBLinearProblem}
34
34
35
35
function linearsolve_forwarddiff_solve (prob:: LinearProblem , alg, args... ; kwargs... )
36
36
@info " here!"
37
37
new_A = nodual_value (prob. A)
38
38
new_b = nodual_value (prob. b)
39
39
40
- newprob = remake (prob; A= new_A, b= new_b)
40
+ newprob = remake (prob; A = new_A, b = new_b)
41
41
42
42
sol = solve (newprob, alg, args... ; kwargs... )
43
43
uu = sol. u
44
44
45
-
46
45
# Solves Dual partials separately
47
46
∂_A = partial_vals (prob. A)
48
47
∂_b = partial_vals (prob. b)
49
48
50
49
rhs_list = xp_linsolve_rhs (uu, ∂_A, ∂_b)
51
50
52
51
partial_sols = map (rhs_list) do rhs
53
- partial_prob = remake (newprob, b= rhs)
52
+ partial_prob = remake (newprob, b = rhs)
54
53
solve (partial_prob, alg, args... ; kwargs... ). u
55
54
end
56
55
@@ -66,7 +65,8 @@ function SciMLBase.solve(prob::DualAbstractLinearProblem, ::Nothing, args...;
66
65
return solve (prob, LinearSolve. defaultalg (prob. A, prob. b, assump), args... ; kwargs... )
67
66
end
68
67
69
- function SciMLBase. solve (prob:: DualAbstractLinearProblem , alg:: LinearSolve.SciMLLinearSolveAlgorithm , args... ; kwargs... )
68
+ function SciMLBase. solve (prob:: DualAbstractLinearProblem ,
69
+ alg:: LinearSolve.SciMLLinearSolveAlgorithm , args... ; kwargs... )
70
70
sol, partials = linearsolve_forwarddiff_solve (
71
71
prob, alg, args... ; kwargs...
72
72
)
@@ -82,28 +82,24 @@ function SciMLBase.solve(prob::DualAbstractLinearProblem, alg::LinearSolve.SciML
82
82
return SciMLBase. build_linear_solution (
83
83
alg, dual_sol, sol. resid, sol. cache; sol. retcode, sol. iters, sol. stats
84
84
)
85
-
86
-
87
85
end
88
86
89
-
90
87
function linearsolve_dual_solution (
91
- u:: Number , partials, dual_type)
88
+ u:: Number , partials, dual_type)
92
89
return dual_type (u, partials)
93
90
end
94
91
95
92
function linearsolve_dual_solution (
96
- u:: AbstractArray , partials, dual_type)
93
+ u:: AbstractArray , partials, dual_type)
97
94
partials_list = RecursiveArrayTools. VectorOfArray (partials)
98
- return map (((uᵢ, pᵢ),) -> dual_type (uᵢ, Partials (Tuple (pᵢ))), zip (u, partials_list[i, :] for i in 1 : length (partials_list[1 ])))
95
+ return map (((uᵢ, pᵢ),) -> dual_type (uᵢ, Partials (Tuple (pᵢ))),
96
+ zip (u, partials_list[i, :] for i in 1 : length (partials_list[1 ])))
99
97
end
100
98
101
-
102
99
get_dual_type (x:: Dual ) = typeof (x)
103
100
get_dual_type (x:: AbstractArray{<:Dual} ) = eltype (x)
104
101
get_dual_type (x) = nothing
105
102
106
-
107
103
partial_vals (x:: Dual ) = ForwardDiff. partials (x)
108
104
partial_vals (x:: AbstractArray{<:Dual} ) = map (ForwardDiff. partials, x)
109
105
partial_vals (x) = nothing
@@ -112,59 +108,51 @@ nodual_value(x) = x
112
108
nodual_value (x:: Dual ) = ForwardDiff. value (x)
113
109
nodual_value (x:: AbstractArray{<:Dual} ) = map (ForwardDiff. value, x)
114
110
115
-
116
- function xp_linsolve_rhs (uu, ∂_A :: Union{<:Partials, <:AbstractArray{<:Partials}} , ∂_b:: Union{<:Partials, <:AbstractArray{<:Partials}} )
111
+ function xp_linsolve_rhs (uu, ∂_A :: Union{<:Partials, <:AbstractArray{<:Partials}} ,
112
+ ∂_b:: Union{<:Partials, <:AbstractArray{<:Partials}} )
117
113
A_list = partials_to_list (∂_A)
118
- b_list = partials_to_list (∂_b)
114
+ b_list = partials_to_list (∂_b)
119
115
120
- Auu = [A* uu for A in A_list]
116
+ Auu = [A * uu for A in A_list]
121
117
122
118
b_list .- Auu
123
119
end
124
120
125
- function xp_linsolve_rhs (uu, ∂_A:: Union{<:Partials, <:AbstractArray{<:Partials}} , ∂_b:: Nothing )
121
+ function xp_linsolve_rhs (
122
+ uu, ∂_A:: Union{<:Partials, <:AbstractArray{<:Partials}} , ∂_b:: Nothing )
126
123
A_list = partials_to_list (∂_A)
127
124
128
- Auu = [A* uu for A in A_list]
125
+ Auu = [A * uu for A in A_list]
129
126
130
127
Auu
131
128
end
132
129
133
- function xp_linsolve_rhs (uu, ∂_A:: Nothing , ∂_b:: Union{<:Partials, <:AbstractArray{<:Partials}} )
130
+ function xp_linsolve_rhs (
131
+ uu, ∂_A:: Nothing , ∂_b:: Union{<:Partials, <:AbstractArray{<:Partials}} )
134
132
b_list = partials_to_list (∂_b)
135
133
136
134
b_list
137
135
end
138
136
139
-
140
-
141
137
function partials_to_list (partial_matrix:: Vector )
142
138
p = eachindex (first (partial_matrix))
143
- [[partial[i] for partial in partial_matrix] for i in p]
139
+ [[partial[i] for partial in partial_matrix] for i in p]
144
140
end
145
141
146
142
function partials_to_list (partial_matrix)
147
143
p = length (first (partial_matrix))
148
- m,n = size (partial_matrix)
149
- res_list = fill (zeros (m,n),p)
144
+ m, n = size (partial_matrix)
145
+ res_list = fill (zeros (m, n), p)
150
146
for k in 1 : p
151
- res = zeros (m,n)
147
+ res = zeros (m, n)
152
148
for i in 1 : m
153
149
for j in 1 : n
154
- res[i,j] = partial_matrix[i,j][k]
150
+ res[i, j] = partial_matrix[i, j][k]
155
151
end
156
152
end
157
153
res_list[k] = res
158
154
end
159
155
return res_list
160
156
end
161
157
162
- end
163
-
164
-
165
-
166
-
167
-
168
-
169
-
170
-
158
+ end
0 commit comments