Skip to content

Commit c419c48

Browse files
committed
use real solve
1 parent fc8c4b5 commit c419c48

File tree

1 file changed

+16
-3
lines changed

1 file changed

+16
-3
lines changed

ext/LinearSolveForwardDiffExt.jl

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,16 @@ function linearsolve_forwarddiff_solve(prob::LinearProblem, alg, args...; kwargs
5050
sol, partial_sols
5151
end
5252

53-
function __solve(prob::DualAbstractLinearProblem, alg, args...; kwargs...)
53+
function SciMLBase.solve(prob::DualAbstractLinearProblem, args...; kwargs...)
54+
return solve(prob, nothing, args...; kwargs...)
55+
end
56+
57+
function SciMLBase.solve(prob::DualAbstractLinearProblem, ::Nothing, args...;
58+
assump = OperatorAssumptions(issquare(prob.A)), kwargs...)
59+
return solve(prob, defaultalg(prob.A, prob.b, assump), args...; kwargs...)
60+
end
61+
62+
function SciMLBase.solve(prob::DualAbstractLinearProblem, alg, args...; kwargs...)
5463
sol, partials = linearsolve_forwarddiff_solve(
5564
prob, alg, args...; kwargs...
5665
)
@@ -59,10 +68,14 @@ function __solve(prob::DualAbstractLinearProblem, alg, args...; kwargs...)
5968
dual_type = get_dual_type(prob.A)
6069
elseif get_dual_type(prob.b) !== nothing
6170
dual_type = get_dual_type(prob.b)
62-
return sol
6371
end
6472

65-
linearsolve_dual_solution(sol.u, partials, dual_type)
73+
dual_sol = linearsolve_dual_solution(sol.u, partials, dual_type)
74+
75+
return SciMLBase.build_linear_solution(
76+
alg, dual_sol, sol.resid, sol.cache; sol.retcode, sol.iters, sol.stats
77+
)
78+
6679

6780
end
6881

0 commit comments

Comments
 (0)