Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 15 additions & 16 deletions ext/LinearSolveForwardDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,8 @@ const DualBLinearProblem = LinearProblem{
const DualAbstractLinearProblem = Union{
DualLinearProblem, DualALinearProblem, DualBLinearProblem}

LinearSolve.@concrete mutable struct DualLinearCache
LinearSolve.@concrete mutable struct DualLinearCache{DT <: Dual}
linear_cache
dual_type

partials_A
partials_b
Expand All @@ -54,7 +53,13 @@ function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwa
primal_b = copy(cache.linear_cache.b)
uu = sol.u

primal_sol = deepcopy(sol)
primal_sol = (;
u = recursivecopy(sol.u),
resid = recursivecopy(sol.resid),
retcode = recursivecopy(sol.retcode),
iters = recursivecopy(sol.iters),
stats = recursivecopy(sol.stats)
)

# Solves Dual partials separately
∂_A = cache.partials_A
Expand Down Expand Up @@ -103,21 +108,15 @@ function xp_linsolve_rhs(
end

function linearsolve_dual_solution(
u::Number, partials, dual_type)
return dual_type(u, partials)
end

function linearsolve_dual_solution(u::Number, partials,
dual_type::Type{<:Dual{T, V, P}}) where {T, V, P}
# Handle single-level duals
return dual_type(u, partials)
u::Number, partials, cache::DualLinearCache{DT}) where {DT}
return DT(u, partials)
end

function linearsolve_dual_solution(u::AbstractArray, partials,
dual_type::Type{<:Dual{T, V, P}}) where {T, V, P}
cache::DualLinearCache{DT}) where {DT}
# Handle single-level duals for arrays
partials_list = RecursiveArrayTools.VectorOfArray(partials)
return map(((uᵢ, pᵢ),) -> dual_type(uᵢ, Partials(Tuple(pᵢ))),
return map(((uᵢ, pᵢ),) -> DT(uᵢ, Partials(Tuple(pᵢ))),
zip(u, partials_list[i, :] for i in 1:length(partials_list.u[1])))
end

Expand Down Expand Up @@ -167,7 +166,7 @@ function __dual_init(
alias = alias, abstol = abstol, reltol = reltol,
maxiters = maxiters, verbose = verbose, Pl = Pl, Pr = Pr, assumptions = assumptions,
sensealg = sensealg, u0 = new_u0, kwargs...)
return DualLinearCache(non_partial_cache, dual_type, ∂_A, ∂_b,
return DualLinearCache{dual_type}(non_partial_cache, ∂_A, ∂_b,
!isnothing(∂_b) ? zero.(∂_b) : ∂_b, A, b, zeros(dual_type, length(b)))
end

Expand All @@ -176,11 +175,11 @@ function SciMLBase.solve!(cache::DualLinearCache, args...; kwargs...)
end

function SciMLBase.solve!(
cache::DualLinearCache, alg::SciMLLinearSolveAlgorithm, args...; kwargs...)
cache::DualLinearCache{DT}, alg::SciMLLinearSolveAlgorithm, args...; kwargs...) where {DT <: ForwardDiff.Dual}
sol,
partials = linearsolve_forwarddiff_solve(
cache::DualLinearCache, cache.alg, args...; kwargs...)
dual_sol = linearsolve_dual_solution(sol.u, partials, cache.dual_type)
dual_sol = linearsolve_dual_solution(sol.u, partials, cache)

if cache.dual_u isa AbstractArray
cache.dual_u[:] = dual_sol
Expand Down
Loading