Skip to content

Commit ad5c511

Browse files
refactor: handle NullParameters, fix independent_variable_symbols for multivariate systems
1 parent 14679d3 commit ad5c511

File tree

1 file changed

+20
-15
lines changed

1 file changed

+20
-15
lines changed

src/systems/abstractsystem.jl

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -185,75 +185,77 @@ end
185185
#Treat the result as a vector of symbols always
186186
function SymbolicIndexingInterface.is_variable(sys::AbstractSystem, sym)
187187
if unwrap(sym) isa Int # [x, 1] coerces 1 to a Num
188-
return unwrap(sym) in 1:length(unknown_states(sys))
188+
return unwrap(sym) in 1:length(variable_symbols(sys))
189189
end
190-
return any(isequal(sym), unknown_states(sys)) ||
190+
return any(isequal(sym), variable_symbols(sys)) ||
191191
hasname(sym) && is_variable(sys, getname(sym))
192192
end
193193

194194
function SymbolicIndexingInterface.is_variable(sys::AbstractSystem, sym::Symbol)
195-
return any(isequal(sym), getname.(unknown_states(sys))) ||
195+
return any(isequal(sym), getname.(variable_symbols(sys))) ||
196196
count('', string(sym)) == 1 &&
197-
count(isequal(sym), Symbol.(sys.name, :₊, getname.(unknown_states(sys)))) == 1
197+
count(isequal(sym), Symbol.(sys.name, :₊, getname.(variable_symbols(sys)))) == 1
198198
end
199199

200200
function SymbolicIndexingInterface.variable_index(sys::AbstractSystem, sym)
201201
if unwrap(sym) isa Int
202202
return unwrap(sym)
203203
end
204-
idx = findfirst(isequal(sym), unknown_states(sys))
204+
idx = findfirst(isequal(sym), variable_symbols(sys))
205205
if idx === nothing && hasname(sym)
206206
idx = variable_index(sys, getname(sym))
207207
end
208208
return idx
209209
end
210210

211211
function SymbolicIndexingInterface.variable_index(sys::AbstractSystem, sym::Symbol)
212-
idx = findfirst(isequal(sym), getname.(unknown_states(sys)))
212+
idx = findfirst(isequal(sym), getname.(variable_symbols(sys)))
213213
if idx !== nothing
214214
return idx
215215
elseif count('', string(sym)) == 1
216-
return findfirst(isequal(sym), Symbol.(sys.name, :₊, getname.(unknown_states(sys))))
216+
return findfirst(isequal(sym), Symbol.(sys.name, :₊, getname.(variable_symbols(sys))))
217217
end
218218
return nothing
219219
end
220220

221+
SymbolicIndexingInterface.variable_symbols(sys::AbstractMultivariateSystem) = sys.dvs
222+
221223
function SymbolicIndexingInterface.variable_symbols(sys::AbstractSystem)
222224
return unknown_states(sys)
223225
end
224226

225227
function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym)
226228
if unwrap(sym) isa Int
227-
return unwrap(sym) in 1:length(parameters(sys))
229+
return unwrap(sym) in 1:length(parameter_symbols(sys))
228230
end
229231

230-
return any(isequal(sym), parameters(sys)) ||
232+
return any(isequal(sym), parameter_symbols(sys)) ||
231233
hasname(sym) && is_parameter(sys, getname(sym))
232234
end
233235

234236
function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym::Symbol)
235-
return any(isequal(sym), getname.(parameters(sys))) ||
237+
return any(isequal(sym), getname.(parameter_symbols(sys))) ||
236238
count('', string(sym)) == 1 &&
237-
count(isequal(sym), Symbol.(sys.name, :₊, getname.(parameters(sys)))) == 1
239+
count(isequal(sym), Symbol.(sys.name, :₊, getname.(parameter_symbols(sys)))) == 1
238240
end
239241

240242
function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym)
241243
if unwrap(sym) isa Int
242244
return unwrap(sym)
243245
end
244-
idx = findfirst(isequal(sym), parameters(sys))
246+
idx = findfirst(isequal(sym), parameter_symbols(sys))
245247
if idx === nothing && hasname(sym)
246248
idx = parameter_index(sys, getname(sym))
247249
end
248250
return idx
249251
end
250252

251253
function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym::Symbol)
252-
idx = findfirst(isequal(sym), getname.(parameters(sys)))
254+
idx = findfirst(isequal(sym), getname.(parameter_symbols(sys)))
253255
if idx !== nothing
254256
return idx
255257
elseif count('', string(sym)) == 1
256-
return findfirst(isequal(sym), Symbol.(sys.name, :₊, getname.(parameters(sys))))
258+
return findfirst(isequal(sym), Symbol.(sys.name, :₊, getname.(parameter_symbols(sys))))
257259
end
258260
return nothing
259261
end
@@ -263,7 +265,7 @@ function SymbolicIndexingInterface.parameter_symbols(sys::AbstractSystem)
263265
end
264266

265267
function SymbolicIndexingInterface.is_independent_variable(sys::AbstractSystem, sym)
266-
return any(isequal(sym), independent_variables(sys))
268+
return any(isequal(sym), independent_variable_symbols(sys))
267269
end
268270

269271
function SymbolicIndexingInterface.is_independent_variable(sys::AbstractSystem, sym::Symbol)
@@ -650,6 +652,9 @@ end
650652

651653
function parameters(sys::AbstractSystem)
652654
ps = get_ps(sys)
655+
if ps == SciMLBase.NullParameters()
656+
return []
657+
end
653658
systems = get_systems(sys)
654659
unique(isempty(systems) ? ps : [ps; reduce(vcat, namespace_parameters.(systems))])
655660
end

0 commit comments

Comments
 (0)