Skip to content

Commit 561c781

Browse files
Merge pull request #801 from AayushSabharwal/as/remake-arridx
fix: support remake with array symbolic whose index is an array of indices
2 parents 06864fd + 5c029b7 commit 561c781

File tree

4 files changed

+57
-18
lines changed

4 files changed

+57
-18
lines changed

Project.toml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,6 @@ LinearAlgebra = "1.10"
7171
Logging = "1.10"
7272
Makie = "0.20, 0.21"
7373
Markdown = "1.10"
74-
ModelingToolkit = "8.75, 9"
7574
PartialFunctions = "1.1"
7675
PrecompileTools = "1.2"
7776
Preferences = "1.3"
@@ -100,7 +99,6 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
10099
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
101100
DelayDiffEq = "bcd4f6db-9728-5f36-b5f7-82caef46ccdb"
102101
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
103-
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
104102
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
105103
PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b"
106104
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
@@ -118,4 +116,4 @@ UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
118116
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
119117

120118
[targets]
121-
test = ["Pkg", "Plots", "UnicodePlots", "PyCall", "PythonCall", "SafeTestsets", "Serialization", "Test", "StableRNGs", "StaticArrays", "StochasticDiffEq", "Aqua", "Zygote", "PartialFunctions", "DataFrames", "ModelingToolkit", "OrdinaryDiffEq", "ForwardDiff"]
119+
test = ["Pkg", "Plots", "UnicodePlots", "PyCall", "PythonCall", "SafeTestsets", "Serialization", "Test", "StableRNGs", "StaticArrays", "StochasticDiffEq", "Aqua", "Zygote", "PartialFunctions", "DataFrames", "OrdinaryDiffEq", "ForwardDiff"]

src/remake.jl

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -549,9 +549,16 @@ function fill_u0(prob, u0; defs = nothing, use_defaults = false)
549549
for (k, v) in u0
550550
idx = variable_index(prob, k)
551551
idx === nothing && continue
552-
sym_to_idx[k] = idx
553-
idx_to_sym[idx] = k
554-
idx_to_val[idx] = v
552+
if !(idx isa AbstractArray) || symbolic_type(k) != ArraySymbolic()
553+
idx = (idx,)
554+
k = (k,)
555+
v = (v,)
556+
end
557+
for (kk, vv, ii) in zip(k, v, idx)
558+
sym_to_idx[kk] = ii
559+
idx_to_sym[ii] = kk
560+
idx_to_val[ii] = vv
561+
end
555562
end
556563
for sym in vsyms
557564
haskey(sym_to_idx, sym) && continue
@@ -586,9 +593,16 @@ function fill_p(prob, p; defs = nothing, use_defaults = false)
586593
for (k, v) in p
587594
idx = parameter_index(prob, k)
588595
idx === nothing && continue
589-
sym_to_idx[k] = idx
590-
idx_to_sym[idx] = k
591-
idx_to_val[idx] = v
596+
if !(idx isa AbstractArray) || symbolic_type(k) != ArraySymbolic()
597+
idx = (idx,)
598+
k = (k,)
599+
v = (v,)
600+
end
601+
for (kk, vv, ii) in zip(k, v, idx)
602+
sym_to_idx[kk] = ii
603+
idx_to_sym[ii] = kk
604+
idx_to_val[ii] = vv
605+
end
592606
end
593607
for sym in psyms
594608
haskey(sym_to_idx, sym) && continue

test/downstream/modelingtoolkit_remake.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,3 +191,31 @@ sol = solve(prob, BFGS())
191191
@test prob2.ps[P] == sign * 2.0
192192
end
193193
end
194+
195+
@testset "remake with Vector{Int} as index of array variable/parameter" begin
196+
@parameters k[1:4]
197+
@variables (V(t))[1:2]
198+
function rhs!(du, u, p, t)
199+
du[1] = p[1] - p[2] * u[1]
200+
du[2] = p[3] - p[4] * u[2]
201+
nothing
202+
end
203+
sys = SymbolCache(Dict(V => 1:2, V[1] => 1, V[2] => 2),
204+
Dict(k => 1:4, k[1] => 1, k[2] => 2, k[3] => 3, k[4] => 4), t)
205+
struct SCWrapper{S}
206+
sys::S
207+
end
208+
SymbolicIndexingInterface.symbolic_container(s::SCWrapper) = s.sys
209+
SymbolicIndexingInterface.variable_symbols(s::SCWrapper) = filter(
210+
x -> symbolic_type(x) != ArraySymbolic(), variable_symbols(s.sys))
211+
SymbolicIndexingInterface.parameter_symbols(s::SCWrapper) = filter(
212+
x -> symbolic_type(x) != ArraySymbolic(), parameter_symbols(s.sys))
213+
sys = SCWrapper(sys)
214+
fn = ODEFunction(rhs!; sys)
215+
oprob_scal_scal = ODEProblem(fn, [10.0, 20.0], (0.0, 1.0), [1.0, 2.0, 3.0, 4.0])
216+
ps_vec = [k => [2.0, 3.0, 4.0, 5.0]]
217+
u0_vec = [V => [1.5, 2.5]]
218+
newoprob = remake(oprob_scal_scal; u0 = u0_vec, p = ps_vec)
219+
@test newoprob.ps[k] == [2.0, 3.0, 4.0, 5.0]
220+
@test newoprob[V] == [1.5, 2.5]
221+
end

test/traits.jl

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
using SciMLBase, Test
2-
using ModelingToolkit, OrdinaryDiffEq, DataFrames
3-
using ModelingToolkit: t_nounits as t, D_nounits as D
2+
using OrdinaryDiffEq, DataFrames, SymbolicIndexingInterface
43

54
@test SciMLBase.Tables.isrowtable(ODESolution)
65
@test SciMLBase.Tables.isrowtable(RODESolution)
@@ -10,13 +9,13 @@ using ModelingToolkit: t_nounits as t, D_nounits as D
109
@test !SciMLBase.Tables.isrowtable(SciMLBase.QuadratureSolution)
1110
@test !SciMLBase.Tables.isrowtable(SciMLBase.OptimizationSolution)
1211

13-
@variables x(t) = 1
14-
eqs = [D(x) ~ -x]
15-
@named sys = ODESystem(eqs, t)
16-
sys = complete(sys)
17-
prob = ODEProblem(sys)
18-
sol = solve(prob, Tsit5(), tspan = (0.0, 1.0))
12+
function rhs(u, p, t)
13+
return -u
14+
end
15+
sys = SymbolCache([:x], Symbol[], :t)
16+
prob = ODEProblem(ODEFunction(rhs; sys), [1.0], (0.0, 1.0))
17+
sol = solve(prob, Tsit5())
1918
df = DataFrame(sol)
2019
@test size(df) == (length(sol.u), 2)
2120
@test df.timestamp == sol.t
22-
@test df.x == sol[x]
21+
@test df.x == sol[:x]

0 commit comments

Comments
 (0)