diff --git a/src/systems/nonlinear/initializesystem.jl b/src/systems/nonlinear/initializesystem.jl index 93fa988b7d..9c8ec73de2 100644 --- a/src/systems/nonlinear/initializesystem.jl +++ b/src/systems/nonlinear/initializesystem.jl @@ -772,7 +772,9 @@ function SciMLBase.late_binding_update_u0_p( return newu0, newp end -function DiffEqBase.get_updated_symbolic_problem(sys::AbstractSystem, prob; kw...) +function DiffEqBase.get_updated_symbolic_problem( + sys::AbstractSystem, prob; u0 = state_values(prob), + p = parameter_values(prob), kw...) supports_initialization(sys) || return prob initdata = prob.f.initialization_data initdata isa SciMLBase.OverrideInitData || return prob @@ -780,10 +782,8 @@ function DiffEqBase.get_updated_symbolic_problem(sys::AbstractSystem, prob; kw.. meta isa InitializationMetadata || return prob meta.get_updated_u0 === nothing && return prob - u0 = state_values(prob) - u0 === nothing && return prob + u0 === nothing && return remake(prob; p) - p = parameter_values(prob) t0 = is_time_dependent(prob) ? current_time(prob) : nothing if p isa MTKParameters @@ -800,7 +800,7 @@ function DiffEqBase.get_updated_symbolic_problem(sys::AbstractSystem, prob; kw.. T = StaticArrays.similar_type(u0) end - return remake(prob; u0 = T(meta.get_updated_u0(prob, initdata.initializeprob))) + return remake(prob; u0 = T(meta.get_updated_u0(prob, initdata.initializeprob)), p) end """ diff --git a/test/extensions/Project.toml b/test/extensions/Project.toml index 6a8d01a7b8..a5ab9a8d0c 100644 --- a/test/extensions/Project.toml +++ b/test/extensions/Project.toml @@ -24,6 +24,7 @@ Pyomo = "0e8e1daf-01b5-4eba-a626-3897743a3816" SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1" SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226" SimpleDiffEq = "05bca326-078c-5bf0-a5bf-ce7c7982d7fd" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/test/extensions/ad.jl b/test/extensions/ad.jl index 7e9cbdd740..53210b66a8 100644 --- a/test/extensions/ad.jl +++ b/test/extensions/ad.jl @@ -8,6 +8,7 @@ using OrdinaryDiffEqNonlinearSolve using NonlinearSolve using SciMLSensitivity using ForwardDiff +using StableRNGs using ChainRulesCore using ChainRulesCore: NoTangent using ChainRulesTestUtils: test_rrule, rand_tangent @@ -136,3 +137,46 @@ end prob[sys.x] end end + +@testset "`p` provided to `solve` is respected" begin + @mtkmodel Linear begin + @variables begin + x(t) = 1.0, [description = "Prey"] + end + @parameters begin + α = 1.5 + end + @equations begin + D(x) ~ -α * x + end + end + + @mtkcompile linear = Linear() + problem = ODEProblem(linear, [], (0.0, 1.0)) + solution = solve(problem, Tsit5(), saveat = 0.1) + rng = StableRNG(42) + data = (; + t = solution.t, + # [[y, x], :] + measurements = Array(solution) + ) + data.measurements .+= 0.05 * randn(rng, size(data.measurements)) + + p0, repack, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), problem.p) + + objective = let repack = repack, problem = problem + (p, data) -> begin + pnew = repack(p) + sol = solve(problem, Tsit5(), p = pnew, saveat = data.t) + sum(abs2, sol .- data.measurements) / size(data.t, 1) + end + end + + # Check 0.0031677344878386607 + @test_nowarn objective(p0, data) + + fd = ForwardDiff.gradient(Base.Fix2(objective, data), p0) + zg = Zygote.gradient(Base.Fix2(objective, data), p0) + + @test fd≈zg[1] atol=1e-6 +end