diff --git a/Project.toml b/Project.toml index c7057a507..db4821fd5 100644 --- a/Project.toml +++ b/Project.toml @@ -102,8 +102,6 @@ OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" -PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0" -PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d" RCall = "6f49c342-dc21-5d91-9882-a32aef131414" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b" @@ -116,4 +114,4 @@ UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Pkg", "Plots", "UnicodePlots", "PyCall", "PythonCall", "SafeTestsets", "Serialization", "Test", "StableRNGs", "StaticArrays", "StochasticDiffEq", "Aqua", "Zygote", "PartialFunctions", "DataFrames", "OrdinaryDiffEq", "ForwardDiff", "Tables"] +test = ["Pkg", "Plots", "UnicodePlots", "SafeTestsets", "Serialization", "Test", "StableRNGs", "StaticArrays", "StochasticDiffEq", "Aqua", "Zygote", "PartialFunctions", "DataFrames", "OrdinaryDiffEq", "ForwardDiff", "Tables"] diff --git a/src/remake.jl b/src/remake.jl index b84625fde..e4540cd68 100644 --- a/src/remake.jl +++ b/src/remake.jl @@ -567,6 +567,7 @@ end function fill_u0(prob, u0; defs = nothing, use_defaults = false) vsyms = variable_symbols(prob) + idx_to_vsym = anydict(variable_index(prob, sym) => sym for sym in vsyms) sym_to_idx = anydict() idx_to_sym = anydict() idx_to_val = anydict() @@ -580,6 +581,8 @@ function fill_u0(prob, u0; defs = nothing, use_defaults = false) v = (v,) end for (kk, vv, ii) in zip(k, v, idx) + sym_to_idx[kk] = ii + kk = idx_to_vsym[ii] sym_to_idx[kk] = ii idx_to_sym[ii] = kk idx_to_val[ii] = vv @@ -612,6 +615,7 @@ end function fill_p(prob, p; defs = nothing, use_defaults = false) psyms = parameter_symbols(prob) + idx_to_psym = anydict(parameter_index(prob, sym) => sym for sym in psyms) sym_to_idx = anydict() idx_to_sym = anydict() idx_to_val = anydict() @@ -625,6 +629,8 @@ function fill_p(prob, p; defs = nothing, use_defaults = false) v = (v,) end for (kk, vv, ii) in zip(k, v, idx) + sym_to_idx[kk] = ii + kk = idx_to_psym[ii] sym_to_idx[kk] = ii idx_to_sym[ii] = kk idx_to_val[ii] = vv @@ -707,6 +713,9 @@ function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{true}, t0) end varmap = merge(u0, p) + if is_time_dependent(prob) + varmap[only(independent_variable_symbols(prob))] = t0 + end for (k, v) in u0 u0[k] = symbolic_type(v) === NotSymbolic() ? v : symbolic_evaluate(v, varmap) end diff --git a/test/downstream/modelingtoolkit_remake.jl b/test/downstream/modelingtoolkit_remake.jl index 38d379ddc..d0df44658 100644 --- a/test/downstream/modelingtoolkit_remake.jl +++ b/test/downstream/modelingtoolkit_remake.jl @@ -229,3 +229,22 @@ end prob2 = remake(prob; u0 = [y => 3.0], p = Dict()) @test prob2.ps[p] ≈ 4.0 end + +@testset "u0 dependent on parameter given as Symbol" begin + @variables x(t) + @parameters p + @mtkbuild sys = ODESystem([D(x) ~ x * p], t) + prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0), [p => 1.0]) + @test prob.ps[p] ≈ 1.0 + prob2 = remake(prob; u0 = [x => p], p = [:p => 2.0]) + @test prob2[x] ≈ 2.0 +end + +@testset "remake dependent on indepvar" begin + @variables x(t) + @parameters p + @mtkbuild sys = ODESystem([D(x) ~ x * p], t) + prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0), [p => 1.0]) + prob2 = remake(prob; u0 = [x => t + 3.0]) + @test prob2[x] ≈ 3.0 +end diff --git a/test/downstream/solution_interface.jl b/test/downstream/solution_interface.jl index 12e2c9402..cb6a07780 100644 --- a/test/downstream/solution_interface.jl +++ b/test/downstream/solution_interface.jl @@ -293,7 +293,7 @@ end prob = ODEProblem(sys, [x => 1.0, y => 1.0], (0.0, 5.0), [p => 0.5, q => 0.0, r => 1.0, s => 10.0, u => 4096.0]) ss = SciMLBase.SavedSubsystem(sys, prob.p, [x, q, s, r]) - @test SciMLBase.get_saved_state_idxs(ss) == [xidx] + @test SciMLBase.get_saved_state_idxs(ss) == [variable_index(sys, x)] sswf = SciMLBase.SavedSubsystemWithFallback(ss, sys) xidx = variable_index(sys, x) qidx = parameter_index(sys, q) diff --git a/test/remake_tests.jl b/test/remake_tests.jl index eff20c58e..dd6fdae5f 100644 --- a/test/remake_tests.jl +++ b/test/remake_tests.jl @@ -297,8 +297,36 @@ a = Remake_Test1(p = 1) @test @inferred remake(a, kwargs = (; a = 1)) == Remake_Test1(p = 1, a = 1) @testset "fill_u0 and fill_p ignore identical variables with different names" begin - sys = SymbolCache(Dict(:x => 1, :x2 => 1, :y => 2), Dict(:a => 1, :a2 => 1, :b => 2), - :t; defaults = Dict(:x => 1, :y => 2, :a => 3, :b => 4)) + struct SCWrapper{S} + sc::S + end + SymbolicIndexingInterface.symbolic_container(s::SCWrapper) = s.sc + function SymbolicIndexingInterface.is_variable(s::SCWrapper, i::Symbol) + if i == :x2 + return is_variable(s.sc, :x) + end + is_variable(s.sc, i) + end + function SymbolicIndexingInterface.variable_index(s::SCWrapper, i::Symbol) + if i == :x2 + return variable_index(s.sc, :x) + end + variable_index(s.sc, i) + end + function SymbolicIndexingInterface.is_parameter(s::SCWrapper, i::Symbol) + if i == :a2 + return is_parameter(s.sc, :a) + end + is_parameter(s.sc, i) + end + function SymbolicIndexingInterface.parameter_index(s::SCWrapper, i::Symbol) + if i == :a2 + return parameter_index(s.sc, :a) + end + parameter_index(s.sc, i) + end + sys = SCWrapper(SymbolCache(Dict(:x => 1, :y => 2), Dict(:a => 1, :b => 2), + :t; defaults = Dict(:x => 1, :y => 2, :a => 3, :b => 4))) function foo(du, u, p, t) du .= u .* p end @@ -306,17 +334,17 @@ a = Remake_Test1(p = 1) u0 = Dict(:x2 => 2) newu0 = SciMLBase.fill_u0(prob, u0; defs = default_values(sys)) @test length(newu0) == 2 - @test get(newu0, :x2, 0) == 2 + @test get(newu0, :x, 0) == 2 @test get(newu0, :y, 0) == 2.5 p = Dict(:a2 => 3) newp = SciMLBase.fill_p(prob, p; defs = default_values(sys)) @test length(newp) == 2 - @test get(newp, :a2, 0) == 3 + @test get(newp, :a, 0) == 3 @test get(newp, :b, 0) == 4.5 end @testset "value of `nothing` is ignored" begin - sys = SymbolCache(Dict(:x => 1, :x2 => 1, :y => 2), Dict(:a => 1, :a2 => 1, :b => 2), + sys = SymbolCache(Dict(:x => 1, :y => 2), Dict(:a => 1, :b => 2), :t; defaults = Dict(:x => 1, :y => 2, :a => 3, :b => 4)) function foo(du, u, p, t) du .= u .* p