@@ -791,6 +791,69 @@ function CommonSolve.solve!(cache::NonlinearSolveNoInitCache)
791
791
return CommonSolve. solve (cache. prob, cache. alg, cache. args... ; cache. kwargs... )
792
792
end
793
793
794
+ function _solve_adjoint (prob, sensealg, u0, p, originator, args... ; merge_callbacks = true ,
795
+ kwargs... )
796
+ alg = extract_alg (args, kwargs, prob. kwargs)
797
+ if isnothing (alg) || ! (alg isa AbstractDEAlgorithm) # Default algorithm handling
798
+ _prob = get_concrete_problem (prob, ! (prob isa DiscreteProblem); u0 = u0,
799
+ p = p, kwargs... )
800
+ else
801
+ _prob = get_concrete_problem (prob, isadaptive (alg); u0 = u0, p = p, kwargs... )
802
+ end
803
+
804
+ if has_kwargs (_prob)
805
+ if merge_callbacks && haskey (_prob. kwargs, :callback ) && haskey (kwargs, :callback )
806
+ kwargs_temp = NamedTuple{
807
+ Base. diff_names (Base. _nt_names (values (kwargs)),
808
+ (:callback ,))}(values (kwargs))
809
+ callbacks = NamedTuple {(:callback,)} ((DiffEqBase. CallbackSet (
810
+ _prob. kwargs[:callback ],
811
+ values (kwargs). callback),))
812
+ kwargs = merge (kwargs_temp, callbacks)
813
+ end
814
+ kwargs = isempty (_prob. kwargs) ? kwargs : merge (values (_prob. kwargs), kwargs)
815
+ end
816
+
817
+ if length (args) > 1
818
+ _concrete_solve_adjoint (_prob, alg, sensealg, u0, p, originator,
819
+ Base. tail (args)... ; kwargs... )
820
+ else
821
+ _concrete_solve_adjoint (_prob, alg, sensealg, u0, p, originator; kwargs... )
822
+ end
823
+ end
824
+
825
+ function _solve_forward (prob, sensealg, u0, p, originator, args... ; merge_callbacks = true ,
826
+ kwargs... )
827
+ alg = extract_alg (args, kwargs, prob. kwargs)
828
+ if isnothing (alg) || ! (alg isa AbstractDEAlgorithm) # Default algorithm handling
829
+ _prob = get_concrete_problem (prob, ! (prob isa DiscreteProblem); u0 = u0,
830
+ p = p, kwargs... )
831
+ else
832
+ _prob = get_concrete_problem (prob, isadaptive (alg); u0 = u0, p = p, kwargs... )
833
+ end
834
+
835
+ if has_kwargs (_prob)
836
+ if merge_callbacks && haskey (_prob. kwargs, :callback ) && haskey (kwargs, :callback )
837
+ kwargs_temp = NamedTuple{
838
+ Base. diff_names (Base. _nt_names (values (kwargs)),
839
+ (:callback ,))}(values (kwargs))
840
+ callbacks = NamedTuple {(:callback,)} ((DiffEqBase. CallbackSet (
841
+ _prob. kwargs[:callback ],
842
+ values (kwargs). callback),))
843
+ kwargs = merge (kwargs_temp, callbacks)
844
+ end
845
+ kwargs = isempty (_prob. kwargs) ? kwargs : merge (values (_prob. kwargs), kwargs)
846
+ end
847
+
848
+ if length (args) > 1
849
+ _concrete_solve_forward (_prob, alg, sensealg, u0, p, originator,
850
+ Base. tail (args)... ; kwargs... )
851
+ else
852
+ _concrete_solve_forward (_prob, alg, sensealg, u0, p, originator; kwargs... )
853
+ end
854
+ end
855
+
856
+
794
857
function get_concrete_problem (prob:: NonlinearProblem , isadapt; kwargs... )
795
858
oldprob = prob
796
859
prob = get_updated_symbolic_problem (get_root_indp (prob), prob; kwargs... )
0 commit comments