1
1
module LinearSolveForwardDiffExt
2
2
3
3
const DualLinearProblem = LinearProblem{
4
- <: Union{Number, <:AbstractArray} , iip,
5
- <: Union{<:Dual{T,V,P}, <:AbstractArray{<:Dual{T,V,P}}} ,
6
- <: Union{<:Dual{T,V,P}, <:AbstractArray{<:Dual{T,V,P}}} ,
7
- <: Union{Number, <:AbstractArray}
8
- } where {iip, T, V}
4
+ <: Union{Number,<:AbstractArray, Nothing} , iip,
5
+ <: Union{<:Dual{T,V,P},<:AbstractArray{<:Dual{T,V,P}}} ,
6
+ <: Union{<:Dual{T,V,P},<:AbstractArray{<:Dual{T,V,P}}} ,
7
+ <: Union{Number,<:AbstractArray, SciMLBase.NullParameters }
8
+ } where {iip, T, V, P }
9
9
10
10
11
11
const DualALinearProblem = LinearProblem{
12
- <: Union{Number, <:AbstractArray} ,
12
+ <: Union{Number,<:AbstractArray, Nothing} ,
13
13
iip,
14
- <: Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}} ,
15
- <: Union{Number, <:AbstractArray} ,
16
- <: Union{Number, <:AbstractArray}
17
- }
14
+ <: Union{<:Dual{T,V, P},<:AbstractArray{<:Dual{T,V, P}}} ,
15
+ <: Union{Number,<:AbstractArray} ,
16
+ <: Union{Number,<:AbstractArray, SciMLBase.NullParameters }
17
+ } where {iip, T, V, P}
18
18
19
19
const DualBLinearProblem = LinearProblem{
20
- <: Union{Number, <:AbstractArray} ,
20
+ <: Union{Number,<:AbstractArray, Nothing} ,
21
21
iip,
22
- <: Union{Number, <:AbstractArray} ,
23
- <: Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}} ,
24
- <: Union{Number, <:AbstractArray}
25
- }
22
+ <: Union{Number,<:AbstractArray} ,
23
+ <: Union{<:Dual{T,V, P},<:AbstractArray{<:Dual{T,V, P}}} ,
24
+ <: Union{Number,<:AbstractArray, SciMLBase.NullParameters }
25
+ } where {iip, T, V, P}
26
26
27
27
const DualAbstractLinearProblem = Union{DualLinearProblem, DualALinearProblem, DualBLinearProblem}
28
28
29
-
30
29
function linearsolve_forwarddiff_solve (prob:: LinearProblem , alg, args... ; kwargs... )
31
30
new_A = nodual_value (prob. A)
32
31
new_b = nodual_value (prob. b)
33
32
34
- newprob = remake (prob; A = new_A, b = new_b)
33
+ newprob = remake (prob; A= new_A, b= new_b)
35
34
36
35
sol = solve (newprob, alg, args... ; kwargs... )
37
36
uu = sol. u
38
37
38
+
39
+ # Solves Dual partials separately
39
40
∂_A = partial_vals (A)
40
41
∂_b = partial_vals (b)
41
42
42
- rhs = xp_linsolve_rhs (uu, ∂_A, ∂_b)
43
+ rhs_list = xp_linsolve_rhs (uu, ∂_A, ∂_b)
43
44
44
- partial_prob = remake (newprob, b = rhs)
45
- partial_sol = solve (partial_prob, alg, args... ; kwargs... )
45
+ partial_sols = map (rhs_list) do rhs
46
+ partial_prob = remake (newprob, b= rhs)
47
+ solve (partial_prob, alg, args... ; kwargs... ). u
48
+ end
46
49
47
- sol, partial_sol
50
+ sol, partial_sols
48
51
end
49
52
53
+ function __solve (prob:: DualAbstractLinearProblem , alg, args... ; kwargs... )
54
+ sol, partials = linearsolve_forwarddiff_solve (
55
+ prob, alg, args... ; kwargs...
56
+ )
57
+
58
+ if get_dual_type (prob. A) != = nothing
59
+ dual_type = get_dual_type (prob. A)
60
+ elseif get_dual_type (prob. b) != = nothing
61
+ dual_type = get_dual_type (prob. b)
62
+ return sol
63
+ end
64
+
65
+ linearsolve_dual_solution (sol. u, partials, dual_type)
66
+
67
+ end
68
+
69
+
70
+ function linearsolve_dual_solution (
71
+ u:: Number , partials, dual_type)
72
+ return dual_type (u, partials)
73
+ end
74
+
75
+ function linearsolve_dual_solution (
76
+ u:: AbstractArray , partials, dual_type)
77
+ partials_list = RecursiveArrayTools. VectorOfArray (partials)
78
+ return map (((uᵢ, pᵢ),) -> dual_type (uᵢ, Partials (Tuple (pᵢ))), zip (u, partials_list[i, :] for i in 1 : length (partials_list[1 ])))
79
+ end
80
+
81
+
82
+ get_dual_type (x:: Dual ) = typeof (x)
83
+ get_dual_type (x:: AbstractArray{<:Dual} ) = eltype (x)
84
+ get_dual_type (x) = nothing
50
85
51
86
52
87
partial_vals (x:: Dual ) = ForwardDiff. partials (x)
53
- partial_vals (x:: AbstractArray{<:Dual} ) = map (ForwardDiff. value , x)
88
+ partial_vals (x:: AbstractArray{<:Dual} ) = map (ForwardDiff. partials , x)
54
89
partial_vals (x) = nothing
55
90
56
91
nodual_value (x) = x
@@ -64,21 +99,21 @@ function xp_linsolve_rhs(uu, ∂_A::Union{<:Partials, <:AbstractArray{<:Partials
64
99
65
100
Auu = [A* uu for A in A_list]
66
101
67
- reduce (hcat, b_list .- Auu)
102
+ b_list .- Auu
68
103
end
69
104
70
105
function xp_linsolve_rhs (uu, ∂_A:: Union{<:Partials, <:AbstractArray{<:Partials}} , ∂_b:: Nothing )
71
106
A_list = partials_to_list (∂_A)
72
107
73
108
Auu = [A* uu for A in A_list]
74
109
75
- reduce (hcat, Auu)
110
+ Auu
76
111
end
77
112
78
113
function xp_linsolve_rhs (uu, ∂_A:: Nothing , ∂_b:: Union{<:Partials, <:AbstractArray{<:Partials}} )
79
114
b_list = partials_to_list (∂_b)
80
115
81
- reduce (hcat, b_list)
116
+ b_list
82
117
end
83
118
84
119
0 commit comments