Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ LinearAlgebra = "1.10"
Logging = "1.10"
Makie = "0.20, 0.21"
Markdown = "1.10"
ModelingToolkit = "8.75, 9"
PartialFunctions = "1.1"
PrecompileTools = "1.2"
Preferences = "1.3"
Expand Down Expand Up @@ -100,7 +99,6 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
DelayDiffEq = "bcd4f6db-9728-5f36-b5f7-82caef46ccdb"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Expand All @@ -118,4 +116,4 @@ UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Pkg", "Plots", "UnicodePlots", "PyCall", "PythonCall", "SafeTestsets", "Serialization", "Test", "StableRNGs", "StaticArrays", "StochasticDiffEq", "Aqua", "Zygote", "PartialFunctions", "DataFrames", "ModelingToolkit", "OrdinaryDiffEq", "ForwardDiff"]
test = ["Pkg", "Plots", "UnicodePlots", "PyCall", "PythonCall", "SafeTestsets", "Serialization", "Test", "StableRNGs", "StaticArrays", "StochasticDiffEq", "Aqua", "Zygote", "PartialFunctions", "DataFrames", "OrdinaryDiffEq", "ForwardDiff"]
26 changes: 20 additions & 6 deletions src/remake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -549,9 +549,16 @@ function fill_u0(prob, u0; defs = nothing, use_defaults = false)
for (k, v) in u0
idx = variable_index(prob, k)
idx === nothing && continue
sym_to_idx[k] = idx
idx_to_sym[idx] = k
idx_to_val[idx] = v
if !(idx isa AbstractArray) || symbolic_type(k) != ArraySymbolic()
idx = (idx,)
k = (k,)
v = (v,)
end
for (kk, vv, ii) in zip(k, v, idx)
sym_to_idx[kk] = ii
idx_to_sym[ii] = kk
idx_to_val[ii] = vv
end
end
for sym in vsyms
haskey(sym_to_idx, sym) && continue
Expand Down Expand Up @@ -586,9 +593,16 @@ function fill_p(prob, p; defs = nothing, use_defaults = false)
for (k, v) in p
idx = parameter_index(prob, k)
idx === nothing && continue
sym_to_idx[k] = idx
idx_to_sym[idx] = k
idx_to_val[idx] = v
if !(idx isa AbstractArray) || symbolic_type(k) != ArraySymbolic()
idx = (idx,)
k = (k,)
v = (v,)
end
for (kk, vv, ii) in zip(k, v, idx)
sym_to_idx[kk] = ii
idx_to_sym[ii] = kk
idx_to_val[ii] = vv
end
end
for sym in psyms
haskey(sym_to_idx, sym) && continue
Expand Down
28 changes: 28 additions & 0 deletions test/downstream/modelingtoolkit_remake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -191,3 +191,31 @@ sol = solve(prob, BFGS())
@test prob2.ps[P] == sign * 2.0
end
end

@testset "remake with Vector{Int} as index of array variable/parameter" begin
@parameters k[1:4]
@variables (V(t))[1:2]
function rhs!(du, u, p, t)
du[1] = p[1] - p[2] * u[1]
du[2] = p[3] - p[4] * u[2]
nothing
end
sys = SymbolCache(Dict(V => 1:2, V[1] => 1, V[2] => 2),
Dict(k => 1:4, k[1] => 1, k[2] => 2, k[3] => 3, k[4] => 4), t)
struct SCWrapper{S}
sys::S
end
SymbolicIndexingInterface.symbolic_container(s::SCWrapper) = s.sys
SymbolicIndexingInterface.variable_symbols(s::SCWrapper) = filter(
x -> symbolic_type(x) != ArraySymbolic(), variable_symbols(s.sys))
SymbolicIndexingInterface.parameter_symbols(s::SCWrapper) = filter(
x -> symbolic_type(x) != ArraySymbolic(), parameter_symbols(s.sys))
sys = SCWrapper(sys)
fn = ODEFunction(rhs!; sys)
oprob_scal_scal = ODEProblem(fn, [10.0, 20.0], (0.0, 1.0), [1.0, 2.0, 3.0, 4.0])
ps_vec = [k => [2.0, 3.0, 4.0, 5.0]]
u0_vec = [V => [1.5, 2.5]]
newoprob = remake(oprob_scal_scal; u0 = u0_vec, p = ps_vec)
@test newoprob.ps[k] == [2.0, 3.0, 4.0, 5.0]
@test newoprob[V] == [1.5, 2.5]
end
17 changes: 8 additions & 9 deletions test/traits.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
using SciMLBase, Test
using ModelingToolkit, OrdinaryDiffEq, DataFrames
using ModelingToolkit: t_nounits as t, D_nounits as D
using OrdinaryDiffEq, DataFrames, SymbolicIndexingInterface

@test SciMLBase.Tables.isrowtable(ODESolution)
@test SciMLBase.Tables.isrowtable(RODESolution)
Expand All @@ -10,13 +9,13 @@ using ModelingToolkit: t_nounits as t, D_nounits as D
@test !SciMLBase.Tables.isrowtable(SciMLBase.QuadratureSolution)
@test !SciMLBase.Tables.isrowtable(SciMLBase.OptimizationSolution)

@variables x(t) = 1
eqs = [D(x) ~ -x]
@named sys = ODESystem(eqs, t)
sys = complete(sys)
prob = ODEProblem(sys)
sol = solve(prob, Tsit5(), tspan = (0.0, 1.0))
function rhs(u, p, t)
return -u
end
sys = SymbolCache([:x], Symbol[], :t)
prob = ODEProblem(ODEFunction(rhs; sys), [1.0], (0.0, 1.0))
sol = solve(prob, Tsit5())
df = DataFrame(sol)
@test size(df) == (length(sol.u), 2)
@test df.timestamp == sol.t
@test df.x == sol[x]
@test df.x == sol[:x]
Loading