@@ -46,15 +46,16 @@ function inplace_vjp(prob, u0, p, verbose, repack)
4646
4747 vjp = try
4848 f = unwrapped_f (prob. f)
49+ tspan_ = prob isa AbstractNonlinearProblem ? nothing : [prob. tspan[1 ]]
4950 if p === nothing || p isa SciMLBase. NullParameters
50- ReverseDiff. GradientTape ((copy (u0), [prob . tspan[ 1 ]] )) do u, t
51+ ReverseDiff. GradientTape ((copy (u0), tspan_ )) do u, t
5152 du1 = similar (u, size (u))
5253 du1 .= 0
5354 f (du1, u, p, first (t))
5455 return vec (du1)
5556 end
5657 else
57- ReverseDiff. GradientTape ((copy (u0), p, [prob . tspan[ 1 ]] )) do u, p, t
58+ ReverseDiff. GradientTape ((copy (u0), p, tspan_ )) do u, p, t
5859 du1 = similar (u, size (u))
5960 du1 .= 0
6061 f (du1, u, repack (p), first (t))
@@ -299,6 +300,7 @@ function DiffEqBase._concrete_solve_adjoint(
299300 tunables, repack = Functors. functor (p)
300301 end
301302
303+ u0 = state_values (prob) === nothing ? Float64[] : u0
302304 default_sensealg = automatic_sensealg_choice (prob, u0, tunables, verbose, repack)
303305 DiffEqBase. _concrete_solve_adjoint (prob, alg, default_sensealg, u0, p,
304306 originator:: SciMLBase.ADOriginator , args... ; verbose,
@@ -371,6 +373,7 @@ function DiffEqBase._concrete_solve_adjoint(
371373 args... ; save_start = true , save_end = true ,
372374 saveat = eltype (prob. tspan)[],
373375 save_idxs = nothing ,
376+ initializealg_default = SciMLBase. OverrideInit (; abstol = 1e-6 , reltol = 1e-3 ),
374377 kwargs... )
375378 if ! (sensealg isa GaussAdjoint) &&
376379 ! (p isa Union{Nothing, SciMLBase. NullParameters, AbstractArray}) ||
@@ -412,16 +415,61 @@ function DiffEqBase._concrete_solve_adjoint(
412415 Base. diff_names (Base. _nt_names (values (kwargs)),
413416 (:callback_adj , :callback ))}(values (kwargs))
414417 isq = sensealg isa QuadratureAdjoint
418+ kwargs_init = kwargs_adj[Base. diff_names (Base. _nt_names (kwargs_adj), (:initializealg ,))]
419+
420+ if haskey (kwargs, :initializealg ) || haskey (prob. kwargs, :initializealg )
421+ initializealg = haskey (kwargs, :initializealg ) ? kwargs[:initializealg ] : prob. kwargs[:initializealg ]
422+ else
423+ initializealg = DefaultInit ()
424+ end
425+
426+ default_inits = Union{OverrideInit, Nothing, DefaultInit}
427+ igs, new_u0, new_p, new_initializealg = if (SciMLBase. has_initialization_data (_prob. f) && initializealg isa default_inits)
428+ local new_u0
429+ local new_p
430+ initializeprob = prob. f. initialization_data. initializeprob
431+ iu0 = state_values (initializeprob)
432+ isAD = if iu0 === nothing
433+ AutoForwardDiff
434+ elseif has_autodiff (alg)
435+ OrdinaryDiffEqCore. alg_autodiff (alg) isa AutoForwardDiff
436+ else
437+ true
438+ end
439+ nlsolve_alg = default_nlsolve (nothing , Val (isinplace (_prob)), iu0, initializeprob, isAD)
440+ initializealg = initializealg isa Union{Nothing, DefaultInit} ? initializealg_default : initializealg
441+
442+ iy, back = Zygote. pullback (tunables) do tunables
443+ new_prob = remake (_prob, p = repack (tunables))
444+ new_u0, new_p, _ = SciMLBase. get_initial_values (new_prob, new_prob, new_prob. f, initializealg, Val (isinplace (new_prob));
445+ sensealg = SteadyStateAdjoint (autojacvec = sensealg. autojacvec),
446+ nlsolve_alg,
447+ kwargs_init... )
448+ new_tunables, _, _ = SciMLStructures. canonicalize (SciMLStructures. Tunable (), new_p)
449+ if SciMLBase. initialization_status (_prob) == SciMLBase. OVERDETERMINED
450+ sum (new_tunables)
451+ else
452+ sum (new_u0) + sum (new_tunables)
453+ end
454+ end
455+ igs = back (one (iy))[1 ] .- one (eltype (tunables))
456+
457+ igs, new_u0, new_p, SciMLBase. NoInit ()
458+ else
459+ nothing , u0, p, initializealg
460+ end
461+ _prob = remake (_prob, u0 = new_u0, p = new_p)
462+
415463 if sensealg isa BacksolveAdjoint
416- sol = solve (_prob, alg, args... ; save_noise = true ,
464+ sol = solve (_prob, alg, args... ; initializealg = new_initializealg, save_noise = true ,
417465 save_start = save_start, save_end = save_end,
418466 saveat = saveat, kwargs_fwd... )
419467 elseif ischeckpointing (sensealg)
420- sol = solve (_prob, alg, args... ; save_noise = true ,
468+ sol = solve (_prob, alg, args... ; initializealg = new_initializealg, save_noise = true ,
421469 save_start = true , save_end = true ,
422470 saveat = saveat, kwargs_fwd... )
423471 else
424- sol = solve (_prob, alg, args... ; save_noise = true , save_start = true ,
472+ sol = solve (_prob, alg, args... ; initializealg = new_initializealg, save_noise = true , save_start = true ,
425473 save_end = true , kwargs_fwd... )
426474 end
427475
@@ -491,6 +539,7 @@ function DiffEqBase._concrete_solve_adjoint(
491539 _save_idxs = save_idxs === nothing ? Colon () : save_idxs
492540
493541 function adjoint_sensitivity_backpass (Δ)
542+ Δ = Δ isa AbstractThunk ? unthunk (Δ) : Δ
494543 function df_iip (_out, u, p, t, i)
495544 outtype = _out isa SubArray ?
496545 ArrayInterface. parameterless_type (_out. parent) :
@@ -628,20 +677,22 @@ function DiffEqBase._concrete_solve_adjoint(
628677 dgdu_discrete = df_iip,
629678 sensealg = sensealg,
630679 callback = cb2,
631- kwargs_adj ... )
680+ kwargs_init ... )
632681 else
633682 du0, dp = adjoint_sensitivities (sol, alg, args... ; t = ts,
634683 dgdu_discrete = df_oop,
635684 sensealg = sensealg,
636685 callback = cb2,
637- kwargs_adj ... )
686+ kwargs_init ... )
638687 end
639688
640689 du0 = reshape (du0, size (u0))
641690
642691 dp = p === nothing || p === DiffEqBase. NullParameters () ? nothing :
643692 dp isa AbstractArray ? reshape (dp' , size (tunables)) : dp
644693
694+ dp = Zygote. accum (dp, igs)
695+
645696 _, repack_adjoint = if p === nothing || p === DiffEqBase. NullParameters () ||
646697 ! isscimlstructure (p)
647698 nothing , x -> (x,)
@@ -1679,6 +1730,7 @@ function DiffEqBase._concrete_solve_adjoint(
16791730 u0, p, originator:: SciMLBase.ADOriginator ,
16801731 args... ; save_idxs = nothing , kwargs... )
16811732 _prob = remake (prob, u0 = u0, p = p)
1733+
16821734 sol = solve (_prob, alg, args... ; kwargs... )
16831735 _save_idxs = save_idxs === nothing ? Colon () : save_idxs
16841736
@@ -1688,26 +1740,74 @@ function DiffEqBase._concrete_solve_adjoint(
16881740 out = SciMLBase. sensitivity_solution (sol, sol[_save_idxs])
16891741 end
16901742
1743+ _, repack_adjoint = if isscimlstructure (p)
1744+ Zygote. pullback (p) do p
1745+ t, _, _ = canonicalize (Tunable (), p)
1746+ t
1747+ end
1748+ elseif isfunctor (p)
1749+ ps, re = Functors. functor (p)
1750+ ps, x -> (re (x),)
1751+ else
1752+ nothing , x -> (x,)
1753+ end
1754+
16911755 function steadystatebackpass (Δ)
1756+ Δ = Δ isa AbstractThunk ? unthunk (Δ) : Δ
16921757 # Δ = dg/dx or diffcache.dg_val
16931758 # del g/del p = 0
16941759 function df (_out, u, p, t, i)
16951760 if _save_idxs isa Number
16961761 _out[_save_idxs] = Δ[_save_idxs]
16971762 elseif Δ isa Number
16981763 @. _out[_save_idxs] = Δ
1699- else
1764+ elseif Δ isa AbstractArray{ <: AbstractArray } || Δ isa AbstractVectorOfArray || Δ isa AbstractArray
17001765 @. _out[_save_idxs] = Δ[_save_idxs]
1766+ elseif isnothing (_out)
1767+ _out
1768+ else
1769+ @. _out[_save_idxs] = Δ. u[_save_idxs]
1770+ end
1771+ end
1772+ dp = adjoint_sensitivities (sol, alg; sensealg = sensealg, dgdu = df, initializealg = BrownFullBasicInit ())
1773+
1774+ dp, Δtunables = if Δ isa AbstractArray || Δ isa Number
1775+ # if Δ isa AbstractArray, the gradients correspond to `u`
1776+ # this is something that needs changing in the future, but
1777+ # this is the applicable till the movement to structuaral
1778+ # tangents is completed
1779+ dp, Δtunables = if isscimlstructure (dp)
1780+ dp, _, _ = canonicalize (Tunable (), dp)
1781+ dp, nothing
1782+ elseif isfunctor (dp)
1783+ dp, _ = Functors. functor (dp)
1784+ dp, nothing
1785+ else
1786+ dp, nothing
1787+ end
1788+ else
1789+ dp, Δtunables = if isscimlstructure (p)
1790+ Δp = setproperties (dp, to_nt (Δ. prob. p))
1791+ Δtunables, _, _ = canonicalize (Tunable (), Δp)
1792+ dp, _, _ = canonicalize (Tunable (), dp)
1793+ dp, Δtunables
1794+ elseif isfunctor (p)
1795+ dp, _ = Functors. functor (dp)
1796+ Δtunables, _ = Functors. functor (Δ. prob. p)
1797+ dp, Δtunables
1798+ else
1799+ dp, Δ. prob. p
17011800 end
17021801 end
1703- dp = adjoint_sensitivities (sol, alg; sensealg = sensealg, dgdu = df)
1802+
1803+ dp = Zygote. accum (dp, (isnothing (Δtunables) || isempty (Δtunables)) ? nothing : Δtunables)
17041804
17051805 if originator isa SciMLBase. TrackerOriginator ||
17061806 originator isa SciMLBase. ReverseDiffOriginator
1707- (NoTangent (), NoTangent (), NoTangent (), dp , NoTangent (),
1807+ (NoTangent (), NoTangent (), NoTangent (), repack_adjoint (dp)[ 1 ] , NoTangent (),
17081808 ntuple (_ -> NoTangent (), length (args))... )
17091809 else
1710- (NoTangent (), NoTangent (), NoTangent (), NoTangent (), dp , NoTangent (),
1810+ (NoTangent (), NoTangent (), NoTangent (), NoTangent (), repack_adjoint (dp)[ 1 ] , NoTangent (),
17111811 ntuple (_ -> NoTangent (), length (args))... )
17121812 end
17131813 end
0 commit comments