diff --git a/src/systems/nonlinear/initializesystem.jl b/src/systems/nonlinear/initializesystem.jl index cf38b94687..5a8687984f 100644 --- a/src/systems/nonlinear/initializesystem.jl +++ b/src/systems/nonlinear/initializesystem.jl @@ -685,7 +685,9 @@ function SciMLBase.late_binding_update_u0_p( return newu0, newp end -function DiffEqBase.get_updated_symbolic_problem(sys::AbstractSystem, prob; kw...) +function DiffEqBase.get_updated_symbolic_problem( + sys::AbstractSystem, prob; u0 = state_values(prob), + p = parameter_values(prob), kw...) supports_initialization(sys) || return prob initdata = prob.f.initialization_data initdata isa SciMLBase.OverrideInitData || return prob @@ -693,10 +695,8 @@ function DiffEqBase.get_updated_symbolic_problem(sys::AbstractSystem, prob; kw.. meta isa InitializationMetadata || return prob meta.get_updated_u0 === nothing && return prob - u0 = state_values(prob) - u0 === nothing && return prob + u0 === nothing && return remake(prob; p) - p = parameter_values(prob) t0 = is_time_dependent(prob) ? current_time(prob) : nothing if p isa MTKParameters @@ -713,7 +713,7 @@ function DiffEqBase.get_updated_symbolic_problem(sys::AbstractSystem, prob; kw.. T = StaticArrays.similar_type(u0) end - return remake(prob; u0 = T(meta.get_updated_u0(prob, initdata.initializeprob))) + return remake(prob; u0 = T(meta.get_updated_u0(prob, initdata.initializeprob)), p) end """ diff --git a/test/extensions/Project.toml b/test/extensions/Project.toml index 9f43e6f4a4..f1647aef1a 100644 --- a/test/extensions/Project.toml +++ b/test/extensions/Project.toml @@ -23,6 +23,7 @@ OrdinaryDiffEqVerner = "79d7bb75-1356-48c1-b8c0-6832512096c2" SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1" SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226" SimpleDiffEq = "05bca326-078c-5bf0-a5bf-ce7c7982d7fd" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/test/extensions/ad.jl b/test/extensions/ad.jl index 14649b6bb6..b5e71bf0af 100644 --- a/test/extensions/ad.jl +++ b/test/extensions/ad.jl @@ -8,6 +8,7 @@ using OrdinaryDiffEqNonlinearSolve using NonlinearSolve using SciMLSensitivity using ForwardDiff +using StableRNGs using ChainRulesCore using ChainRulesCore: NoTangent using ChainRulesTestUtils: test_rrule, rand_tangent @@ -135,3 +136,46 @@ end prob[sys.x] end end + +@testset "`p` provided to `solve` is respected" begin + @mtkmodel Linear begin + @variables begin + x(t) = 1.0, [description = "Prey"] + end + @parameters begin + α = 1.5 + end + @equations begin + D(x) ~ -α * x + end + end + + @mtkbuild linear = Linear() + problem = ODEProblem(linear, [], (0.0, 1.0)) + solution = solve(problem, Tsit5(), saveat = 0.1) + rng = StableRNG(42) + data = (; + t = solution.t, + # [[y, x], :] + measurements = Array(solution) + ) + data.measurements .+= 0.05 * randn(rng, size(data.measurements)) + + p0, repack, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), problem.p) + + objective = let repack = repack, problem = problem + (p, data) -> begin + pnew = repack(p) + sol = solve(problem, Tsit5(), p = pnew, saveat = data.t) + sum(abs2, sol .- data.measurements) / size(data.t, 1) + end + end + + # Check 0.0031677344878386607 + @test_nowarn objective(p0, data) + + fd = ForwardDiff.gradient(Base.Fix2(objective, data), p0) + zg = Zygote.gradient(Base.Fix2(objective, data), p0) + + @test fd≈zg[1] atol=1e-6 +end