diff --git a/src/remake.jl b/src/remake.jl index 750413a1f..280957671 100644 --- a/src/remake.jl +++ b/src/remake.jl @@ -125,8 +125,8 @@ function remake(prob::ODEProblem; f = missing, if f === missing if build_initializeprob - initialization_data = remake_initialization_data( - prob.f.sys, prob.f, u0, tspan[1], p) + initialization_data = remake_initialization_data_compat_wrapper( + prob.f.sys, prob.f, u0, tspan[1], p, newu0, newp) else initialization_data = nothing end @@ -203,16 +203,32 @@ function remake_initializeprob(sys, scimlfn, u0, t0, p) end """ - remake_initialization_data(sys, scimlfn, u0, t0, p) + $(TYPEDSIGNATURES) + +Wrapper around `remake_initialization_data` for backward compatibility when `newu0` and +`newp` were not arguments. +""" +function remake_initialization_data_compat_wrapper(sys, scimlfn, u0, t0, p, newu0, newp) + if hasmethod(remake_initialization_data, + Tuple{typeof(sys), typeof(scimlfn), typeof(u0), typeof(t0), typeof(p)}) + remake_initialization_data(sys, scimlfn, u0, t0, p) + else + remake_initialization_data(sys, scimlfn, u0, t0, p, newu0, newp) + end +end + +""" + remake_initialization_data(sys, scimlfn, u0, t0, p, newu0, newp) Re-create the initialization data present in the function `scimlfn`, using the -associated system `sys` and the user provided new values of `u0`, initial time `t0` and -`p`. By default, this calls `remake_initializeprob` for backward compatibility and -attempts to construct an `OverrideInitData` from the result. +associated system `sys`, the user provided new values of `u0`, initial time `t0`, +user-provided `p`, new u0 vector `newu0` and new parameter object `newp`. By default, +this calls `remake_initializeprob` for backward compatibility and attempts to construct +an `OverrideInitData` from the result. Note that `u0` or `p` may be `missing` if the user does not provide a value for them. """ -function remake_initialization_data(sys, scimlfn, u0, t0, p) +function remake_initialization_data(sys, scimlfn, u0, t0, p, newu0, newp) return reconstruct_initialization_data( nothing, remake_initializeprob(sys, scimlfn, u0, t0, p)...) end diff --git a/test/downstream/adjoints.jl b/test/downstream/adjoints.jl index 327172ef7..4e75e19b6 100644 --- a/test/downstream/adjoints.jl +++ b/test/downstream/adjoints.jl @@ -68,7 +68,7 @@ gs_ts, = Zygote.gradient(sol) do sol sum(sum.(sol[[lorenz1.x, lorenz2.x], :])) end -@test_broken all(map(x -> x == true_grad_vecsym, gs_ts)) +@test all(map(x -> x == true_grad_vecsym, gs_ts)) # BatchedInterface AD @variables x(t)=1.0 y(t)=1.0 z(t)=1.0 w(t)=1.0