Skip to content

Commit d6e39bd

Browse files
committed
add solve_adjoint and solve_forward
1 parent aa72cf2 commit d6e39bd

File tree

1 file changed

+63
-0
lines changed

1 file changed

+63
-0
lines changed

lib/NonlinearSolveBase/src/solve.jl

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -791,6 +791,69 @@ function CommonSolve.solve!(cache::NonlinearSolveNoInitCache)
791791
return CommonSolve.solve(cache.prob, cache.alg, cache.args...; cache.kwargs...)
792792
end
793793

794+
function _solve_adjoint(prob, sensealg, u0, p, originator, args...; merge_callbacks = true,
795+
kwargs...)
796+
alg = extract_alg(args, kwargs, prob.kwargs)
797+
if isnothing(alg) || !(alg isa AbstractDEAlgorithm) # Default algorithm handling
798+
_prob = get_concrete_problem(prob, !(prob isa DiscreteProblem); u0 = u0,
799+
p = p, kwargs...)
800+
else
801+
_prob = get_concrete_problem(prob, isadaptive(alg); u0 = u0, p = p, kwargs...)
802+
end
803+
804+
if has_kwargs(_prob)
805+
if merge_callbacks && haskey(_prob.kwargs, :callback) && haskey(kwargs, :callback)
806+
kwargs_temp = NamedTuple{
807+
Base.diff_names(Base._nt_names(values(kwargs)),
808+
(:callback,))}(values(kwargs))
809+
callbacks = NamedTuple{(:callback,)}((DiffEqBase.CallbackSet(
810+
_prob.kwargs[:callback],
811+
values(kwargs).callback),))
812+
kwargs = merge(kwargs_temp, callbacks)
813+
end
814+
kwargs = isempty(_prob.kwargs) ? kwargs : merge(values(_prob.kwargs), kwargs)
815+
end
816+
817+
if length(args) > 1
818+
_concrete_solve_adjoint(_prob, alg, sensealg, u0, p, originator,
819+
Base.tail(args)...; kwargs...)
820+
else
821+
_concrete_solve_adjoint(_prob, alg, sensealg, u0, p, originator; kwargs...)
822+
end
823+
end
824+
825+
function _solve_forward(prob, sensealg, u0, p, originator, args...; merge_callbacks = true,
826+
kwargs...)
827+
alg = extract_alg(args, kwargs, prob.kwargs)
828+
if isnothing(alg) || !(alg isa AbstractDEAlgorithm) # Default algorithm handling
829+
_prob = get_concrete_problem(prob, !(prob isa DiscreteProblem); u0 = u0,
830+
p = p, kwargs...)
831+
else
832+
_prob = get_concrete_problem(prob, isadaptive(alg); u0 = u0, p = p, kwargs...)
833+
end
834+
835+
if has_kwargs(_prob)
836+
if merge_callbacks && haskey(_prob.kwargs, :callback) && haskey(kwargs, :callback)
837+
kwargs_temp = NamedTuple{
838+
Base.diff_names(Base._nt_names(values(kwargs)),
839+
(:callback,))}(values(kwargs))
840+
callbacks = NamedTuple{(:callback,)}((DiffEqBase.CallbackSet(
841+
_prob.kwargs[:callback],
842+
values(kwargs).callback),))
843+
kwargs = merge(kwargs_temp, callbacks)
844+
end
845+
kwargs = isempty(_prob.kwargs) ? kwargs : merge(values(_prob.kwargs), kwargs)
846+
end
847+
848+
if length(args) > 1
849+
_concrete_solve_forward(_prob, alg, sensealg, u0, p, originator,
850+
Base.tail(args)...; kwargs...)
851+
else
852+
_concrete_solve_forward(_prob, alg, sensealg, u0, p, originator; kwargs...)
853+
end
854+
end
855+
856+
794857
function get_concrete_problem(prob::NonlinearProblem, isadapt; kwargs...)
795858
oldprob = prob
796859
prob = get_updated_symbolic_problem(get_root_indp(prob), prob; kwargs...)

0 commit comments

Comments
 (0)