1
1
module LinearSolveForwardDiffExt
2
2
3
+ using LinearSolve
4
+ using ForwardDiff
5
+ using ForwardDiff: Dual, Partials
6
+ using SciMLBase
7
+ using RecursiveArrayTools
8
+
3
9
const DualLinearProblem = LinearProblem{
4
10
<: Union{Number,<:AbstractArray, Nothing} ,iip,
5
11
<: Union{<:Dual{T,V,P},<:AbstractArray{<:Dual{T,V,P}}} ,
@@ -27,6 +33,7 @@ const DualBLinearProblem = LinearProblem{
27
33
const DualAbstractLinearProblem = Union{DualLinearProblem, DualALinearProblem, DualBLinearProblem}
28
34
29
35
function linearsolve_forwarddiff_solve (prob:: LinearProblem , alg, args... ; kwargs... )
36
+ @info " here!"
30
37
new_A = nodual_value (prob. A)
31
38
new_b = nodual_value (prob. b)
32
39
@@ -37,8 +44,8 @@ function linearsolve_forwarddiff_solve(prob::LinearProblem, alg, args...; kwargs
37
44
38
45
39
46
# Solves Dual partials separately
40
- ∂_A = partial_vals (A)
41
- ∂_b = partial_vals (b)
47
+ ∂_A = partial_vals (prob . A)
48
+ ∂_b = partial_vals (prob . b)
42
49
43
50
rhs_list = xp_linsolve_rhs (uu, ∂_A, ∂_b)
44
51
56
63
57
64
function SciMLBase. solve (prob:: DualAbstractLinearProblem , :: Nothing , args... ;
58
65
assump = OperatorAssumptions (issquare (prob. A)), kwargs... )
59
- return solve (prob, defaultalg (prob. A, prob. b, assump), args... ; kwargs... )
66
+ return solve (prob, LinearSolve . defaultalg (prob. A, prob. b, assump), args... ; kwargs... )
60
67
end
61
68
62
- function SciMLBase. solve (prob:: DualAbstractLinearProblem , alg, args... ; kwargs... )
69
+ function SciMLBase. solve (prob:: DualAbstractLinearProblem , alg:: LinearSolve.SciMLLinearSolveAlgorithm , args... ; kwargs... )
63
70
sol, partials = linearsolve_forwarddiff_solve (
64
71
prob, alg, args... ; kwargs...
65
72
)
@@ -152,7 +159,7 @@ function partials_to_list(partial_matrix)
152
159
return res_list
153
160
end
154
161
155
-
162
+ end
156
163
157
164
158
165
0 commit comments