Skip to content

Commit 677570f

Browse files
committed
add partial linsolve
1 parent c57c9d8 commit 677570f

File tree

1 file changed

+10
-18
lines changed

1 file changed

+10
-18
lines changed

ext/LinearSolveForwardDiffExt.jl

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -39,14 +39,12 @@ function linearsolve_forwarddiff_solve(prob::LinearProblem, alg, args...; kwargs
3939
∂_A = partial_vals(A)
4040
∂_b = partial_vals(b)
4141

42-
42+
rhs = xp_linsolve_rhs(uu, ∂_A, ∂_b)
4343

44-
if uu isa Number
45-
46-
else
47-
48-
end
44+
partial_prob = remake(newprob, b = rhs)
45+
partial_sol = solve(partial_prob, alg, args...; kwargs...)
4946

47+
sol, partial_sol
5048
end
5149

5250

@@ -60,33 +58,27 @@ nodual_value(x::Dual) = ForwardDiff.value(x)
6058
nodual_value(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x)
6159

6260

63-
function x_p_linsolve(new_A, uu, ∂_A::Union{<:Partials, <:AbstractArray{<:Partials}}, ∂_b::Union{<:Partials, <:AbstractArray{<:Partials}})
61+
function xp_linsolve_rhs(uu, ∂_A::Union{<:Partials, <:AbstractArray{<:Partials}}, ∂_b::Union{<:Partials, <:AbstractArray{<:Partials}})
6462
A_list = partials_to_list(∂_A)
6563
b_list = partials_to_list(∂_b)
6664

6765
Auu = [A*uu for A in A_list]
6866

69-
linsol_rhs = reduce(hcat, b_list .- Auu)
70-
71-
new_A \ linsol_rhs
67+
reduce(hcat, b_list .- Auu)
7268
end
7369

74-
function x_p_linsolve(new_A, uu, ∂_A::Union{<:Partials, <:AbstractArray{<:Partials}}, ∂_b::Nothing)
70+
function xp_linsolve_rhs(uu, ∂_A::Union{<:Partials, <:AbstractArray{<:Partials}}, ∂_b::Nothing)
7571
A_list = partials_to_list(∂_A)
7672

7773
Auu = [A*uu for A in A_list]
7874

79-
linsol_rhs = reduce(hcat, Auu)
80-
81-
new_A \ linsol_rhs
75+
reduce(hcat, Auu)
8276
end
8377

84-
function x_p_linsolve(new_A, uu, ∂_A::Nothing, ∂_b::Union{<:Partials, <:AbstractArray{<:Partials}})
78+
function xp_linsolve_rhs(uu, ∂_A::Nothing, ∂_b::Union{<:Partials, <:AbstractArray{<:Partials}})
8579
b_list = partials_to_list(∂_b)
8680

87-
linsol_rhs = reduce(hcat, b_list)
88-
89-
new_A \ linsol_rhs
81+
reduce(hcat, b_list)
9082
end
9183

9284

0 commit comments

Comments
 (0)