Skip to content

Commit dee5752

Browse files
Merge pull request #2391 from AayushSabharwal/as/sii-multivariate
refactor: handle NullParameters, fix independent_variable_symbols for multivariate systems
2 parents 35be92d + 1c063c0 commit dee5752

File tree

2 files changed

+100
-15
lines changed

2 files changed

+100
-15
lines changed

src/systems/abstractsystem.jl

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -185,75 +185,77 @@ end
185185
#Treat the result as a vector of symbols always
186186
function SymbolicIndexingInterface.is_variable(sys::AbstractSystem, sym)
187187
if unwrap(sym) isa Int # [x, 1] coerces 1 to a Num
188-
return unwrap(sym) in 1:length(unknown_states(sys))
188+
return unwrap(sym) in 1:length(variable_symbols(sys))
189189
end
190-
return any(isequal(sym), unknown_states(sys)) ||
190+
return any(isequal(sym), variable_symbols(sys)) ||
191191
hasname(sym) && is_variable(sys, getname(sym))
192192
end
193193

194194
function SymbolicIndexingInterface.is_variable(sys::AbstractSystem, sym::Symbol)
195-
return any(isequal(sym), getname.(unknown_states(sys))) ||
195+
return any(isequal(sym), getname.(variable_symbols(sys))) ||
196196
count('', string(sym)) == 1 &&
197-
count(isequal(sym), Symbol.(sys.name, :₊, getname.(unknown_states(sys)))) == 1
197+
count(isequal(sym), Symbol.(sys.name, :₊, getname.(variable_symbols(sys)))) == 1
198198
end
199199

200200
function SymbolicIndexingInterface.variable_index(sys::AbstractSystem, sym)
201201
if unwrap(sym) isa Int
202202
return unwrap(sym)
203203
end
204-
idx = findfirst(isequal(sym), unknown_states(sys))
204+
idx = findfirst(isequal(sym), variable_symbols(sys))
205205
if idx === nothing && hasname(sym)
206206
idx = variable_index(sys, getname(sym))
207207
end
208208
return idx
209209
end
210210

211211
function SymbolicIndexingInterface.variable_index(sys::AbstractSystem, sym::Symbol)
212-
idx = findfirst(isequal(sym), getname.(unknown_states(sys)))
212+
idx = findfirst(isequal(sym), getname.(variable_symbols(sys)))
213213
if idx !== nothing
214214
return idx
215215
elseif count('', string(sym)) == 1
216-
return findfirst(isequal(sym), Symbol.(sys.name, :₊, getname.(unknown_states(sys))))
216+
return findfirst(isequal(sym), Symbol.(sys.name, :₊, getname.(variable_symbols(sys))))
217217
end
218218
return nothing
219219
end
220220

221+
SymbolicIndexingInterface.variable_symbols(sys::AbstractMultivariateSystem) = sys.dvs
222+
221223
function SymbolicIndexingInterface.variable_symbols(sys::AbstractSystem)
222224
return unknown_states(sys)
223225
end
224226

225227
function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym)
226228
if unwrap(sym) isa Int
227-
return unwrap(sym) in 1:length(parameters(sys))
229+
return unwrap(sym) in 1:length(parameter_symbols(sys))
228230
end
229231

230-
return any(isequal(sym), parameters(sys)) ||
232+
return any(isequal(sym), parameter_symbols(sys)) ||
231233
hasname(sym) && is_parameter(sys, getname(sym))
232234
end
233235

234236
function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym::Symbol)
235-
return any(isequal(sym), getname.(parameters(sys))) ||
237+
return any(isequal(sym), getname.(parameter_symbols(sys))) ||
236238
count('', string(sym)) == 1 &&
237-
count(isequal(sym), Symbol.(sys.name, :₊, getname.(parameters(sys)))) == 1
239+
count(isequal(sym), Symbol.(sys.name, :₊, getname.(parameter_symbols(sys)))) == 1
238240
end
239241

240242
function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym)
241243
if unwrap(sym) isa Int
242244
return unwrap(sym)
243245
end
244-
idx = findfirst(isequal(sym), parameters(sys))
246+
idx = findfirst(isequal(sym), parameter_symbols(sys))
245247
if idx === nothing && hasname(sym)
246248
idx = parameter_index(sys, getname(sym))
247249
end
248250
return idx
249251
end
250252

251253
function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym::Symbol)
252-
idx = findfirst(isequal(sym), getname.(parameters(sys)))
254+
idx = findfirst(isequal(sym), getname.(parameter_symbols(sys)))
253255
if idx !== nothing
254256
return idx
255257
elseif count('', string(sym)) == 1
256-
return findfirst(isequal(sym), Symbol.(sys.name, :₊, getname.(parameters(sys))))
258+
return findfirst(isequal(sym), Symbol.(sys.name, :₊, getname.(parameter_symbols(sys))))
257259
end
258260
return nothing
259261
end
@@ -263,7 +265,7 @@ function SymbolicIndexingInterface.parameter_symbols(sys::AbstractSystem)
263265
end
264266

265267
function SymbolicIndexingInterface.is_independent_variable(sys::AbstractSystem, sym)
266-
return any(isequal(sym), independent_variables(sys))
268+
return any(isequal(sym), independent_variable_symbols(sys))
267269
end
268270

269271
function SymbolicIndexingInterface.is_independent_variable(sys::AbstractSystem, sym::Symbol)
@@ -650,6 +652,12 @@ end
650652

651653
function parameters(sys::AbstractSystem)
652654
ps = get_ps(sys)
655+
if ps == SciMLBase.NullParameters()
656+
return []
657+
end
658+
if eltype(ps) <: Pair
659+
ps = first.(ps)
660+
end
653661
systems = get_systems(sys)
654662
unique(isempty(systems) ? ps : [ps; reduce(vcat, namespace_parameters.(systems))])
655663
end

test/symbolic_indexing_interface.jl

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
using ModelingToolkit, SymbolicIndexingInterface, SciMLBase
2+
3+
@parameters t a b
4+
@variables x(t) y(t)
5+
D = Differential(t)
6+
eqs = [D(x) ~ a * y + t, D(y) ~ b * t]
7+
@named odesys = ODESystem(eqs, t, [x, y], [a, b])
8+
9+
@test all(is_variable.((odesys,), [x, y, 1, 2, :x, :y]))
10+
@test all(.!is_variable.((odesys,), [a, b, t, 3, 0, :a, :b]))
11+
@test variable_index.((odesys,), [x, y, a, b, t, 1, 2, :x, :y, :a, :b]) == [1, 2, nothing, nothing, nothing, 1, 2, 1, 2, nothing, nothing]
12+
@test isequal(variable_symbols(odesys), [x, y])
13+
@test all(is_parameter.((odesys,), [a, b, 1, 2, :a, :b]))
14+
@test all(.!is_parameter.((odesys,), [x, y, t, 3, 0, :x, :y]))
15+
@test parameter_index.((odesys,), [x, y, a, b, t, 1, 2, :x, :y, :a, :b]) == [nothing, nothing, 1, 2, nothing, 1, 2, nothing, nothing, 1, 2]
16+
@test isequal(parameter_symbols(odesys), [a, b])
17+
@test all(is_independent_variable.((odesys,), [t, :t]))
18+
@test all(.!is_independent_variable.((odesys,), [x, y, a, :x, :y, :a]))
19+
@test isequal(independent_variable_symbols(odesys), [t])
20+
@test is_time_dependent(odesys)
21+
@test constant_structure(odesys)
22+
23+
@variables x y z
24+
@parameters σ ρ β
25+
26+
eqs = [0 ~ σ*(y-x),
27+
0 ~ x*-z)-y,
28+
0 ~ x*y - β*z]
29+
@named ns = NonlinearSystem(eqs, [x,y,z],[σ,ρ,β])
30+
31+
@test !is_time_dependent(ns)
32+
33+
@parameters x
34+
@variables t u(..)
35+
Dxx = Differential(x)^2
36+
Dtt = Differential(t)^2
37+
Dt = Differential(t)
38+
39+
#2D PDE
40+
C=1
41+
eq = Dtt(u(t,x)) ~ C^2*Dxx(u(t,x))
42+
43+
# Initial and boundary conditions
44+
bcs = [u(t,0) ~ 0.,# for all t > 0
45+
u(t,1) ~ 0.,# for all t > 0
46+
u(0,x) ~ x*(1. - x), #for all 0 < x < 1
47+
Dt(u(0,x)) ~ 0. ] #for all 0 < x < 1]
48+
49+
# Space and time domains
50+
domains = [t (0.0,1.0),
51+
x (0.0,1.0)]
52+
53+
@named pde_system = PDESystem(eq,bcs,domains,[t,x],[u])
54+
55+
@test pde_system.ps == SciMLBase.NullParameters()
56+
@test parameter_symbols(pde_system) == []
57+
58+
@parameters t x
59+
@constants h = 1
60+
@variables u(..)
61+
Dt = Differential(t)
62+
Dxx = Differential(x)^2
63+
eq = Dt(u(t, x)) ~ h * Dxx(u(t, x))
64+
bcs = [u(0, x) ~ -h * x * (x - 1) * sin(x),
65+
u(t, 0) ~ 0, u(t, 1) ~ 0]
66+
67+
domains = [t (0.0, 1.0),
68+
x (0.0, 1.0)]
69+
70+
analytic = [u(t, x) ~ -h * x * (x - 1) * sin(x) * exp(-2 * h * t)]
71+
analytic_function = (ps, t, x) -> -ps[1] * x * (x - 1) * sin(x) * exp(-2 * ps[1] * t)
72+
73+
@named pdesys = PDESystem(eq, bcs, domains, [t, x], [u], [h => 1], analytic = analytic)
74+
75+
@test isequal(pdesys.ps, [h => 1])
76+
@test isequal(parameter_symbols(pdesys), [h])
77+
@test isequal(parameters(pdesys), [h])

0 commit comments

Comments
 (0)