Skip to content

Commit 81030fa

Browse files
committed
add imports and fix partial_val
1 parent d6bddf9 commit 81030fa

File tree

1 file changed

+12
-5
lines changed

1 file changed

+12
-5
lines changed

ext/LinearSolveForwardDiffExt.jl

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
11
module LinearSolveForwardDiffExt
22

3+
using LinearSolve
4+
using ForwardDiff
5+
using ForwardDiff: Dual, Partials
6+
using SciMLBase
7+
using RecursiveArrayTools
8+
39
const DualLinearProblem = LinearProblem{
410
<:Union{Number,<:AbstractArray, Nothing},iip,
511
<:Union{<:Dual{T,V,P},<:AbstractArray{<:Dual{T,V,P}}},
@@ -27,6 +33,7 @@ const DualBLinearProblem = LinearProblem{
2733
const DualAbstractLinearProblem = Union{DualLinearProblem, DualALinearProblem, DualBLinearProblem}
2834

2935
function linearsolve_forwarddiff_solve(prob::LinearProblem, alg, args...; kwargs...)
36+
@info "here!"
3037
new_A = nodual_value(prob.A)
3138
new_b = nodual_value(prob.b)
3239

@@ -37,8 +44,8 @@ function linearsolve_forwarddiff_solve(prob::LinearProblem, alg, args...; kwargs
3744

3845

3946
# Solves Dual partials separately
40-
∂_A = partial_vals(A)
41-
∂_b = partial_vals(b)
47+
∂_A = partial_vals(prob.A)
48+
∂_b = partial_vals(prob.b)
4249

4350
rhs_list = xp_linsolve_rhs(uu, ∂_A, ∂_b)
4451

@@ -56,10 +63,10 @@ end
5663

5764
function SciMLBase.solve(prob::DualAbstractLinearProblem, ::Nothing, args...;
5865
assump = OperatorAssumptions(issquare(prob.A)), kwargs...)
59-
return solve(prob, defaultalg(prob.A, prob.b, assump), args...; kwargs...)
66+
return solve(prob, LinearSolve.defaultalg(prob.A, prob.b, assump), args...; kwargs...)
6067
end
6168

62-
function SciMLBase.solve(prob::DualAbstractLinearProblem, alg, args...; kwargs...)
69+
function SciMLBase.solve(prob::DualAbstractLinearProblem, alg::LinearSolve.SciMLLinearSolveAlgorithm, args...; kwargs...)
6370
sol, partials = linearsolve_forwarddiff_solve(
6471
prob, alg, args...; kwargs...
6572
)
@@ -152,7 +159,7 @@ function partials_to_list(partial_matrix)
152159
return res_list
153160
end
154161

155-
162+
end
156163

157164

158165

0 commit comments

Comments
 (0)