Skip to content

Commit 05710fe

Browse files
Merge pull request #2345 from AayushSabharwal/as/indexing-rework
feat: implementation of new SymbolicIndexingInterface
2 parents e7fe1b5 + d179b4f commit 05710fe

File tree

4 files changed

+114
-32
lines changed

4 files changed

+114
-32
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ MacroTools = "0.5"
8888
NaNMath = "0.3, 1"
8989
OrdinaryDiffEq = "6"
9090
PrecompileTools = "1"
91-
RecursiveArrayTools = "2.3"
91+
RecursiveArrayTools = "2.3, 3"
9292
Reexport = "0.2, 1"
9393
RuntimeGeneratedFunctions = "0.5.9"
9494
SciMLBase = "2.0.1"
@@ -98,7 +98,7 @@ SimpleNonlinearSolve = "0.1.0, 1"
9898
SparseArrays = "1"
9999
SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2"
100100
StaticArrays = "0.10, 0.11, 0.12, 1.0"
101-
SymbolicIndexingInterface = "0.1, 0.2"
101+
SymbolicIndexingInterface = "0.3"
102102
SymbolicUtils = "1.0"
103103
Symbolics = "5.7"
104104
URIs = "1"

src/ModelingToolkit.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,7 @@ using PrecompileTools, Reexport
3535

3636
using RecursiveArrayTools
3737

38-
import SymbolicIndexingInterface
39-
import SymbolicIndexingInterface: independent_variables, states, parameters
38+
using SymbolicIndexingInterface
4039
export independent_variables, states, parameters
4140
import SymbolicUtils
4241
import SymbolicUtils: istree, arguments, operation, similarterm, promote_symtype,

src/systems/abstractsystem.jl

Lines changed: 109 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -161,10 +161,18 @@ function independent_variable(sys::AbstractSystem)
161161
isdefined(sys, :iv) ? getfield(sys, :iv) : nothing
162162
end
163163

164-
#Treat the result as a vector of symbols always
165-
function SymbolicIndexingInterface.independent_variables(sys::AbstractSystem)
166-
systype = typeof(sys)
167-
@warn "Please declare ($systype) as a subtype of `AbstractTimeDependentSystem`, `AbstractTimeIndependentSystem` or `AbstractMultivariateSystem`."
164+
function independent_variables(sys::AbstractTimeDependentSystem)
165+
return [getfield(sys, :iv)]
166+
end
167+
168+
independent_variables(::AbstractTimeIndependentSystem) = []
169+
170+
function independent_variables(sys::AbstractMultivariateSystem)
171+
return getfield(sys, :ivs)
172+
end
173+
174+
function independent_variables(sys::AbstractSystem)
175+
@warn "Please declare ($(typeof(sys))) as a subtype of `AbstractTimeDependentSystem`, `AbstractTimeIndependentSystem` or `AbstractMultivariateSystem`."
168176
if isdefined(sys, :iv)
169177
return [getfield(sys, :iv)]
170178
elseif isdefined(sys, :ivs)
@@ -174,14 +182,102 @@ function SymbolicIndexingInterface.independent_variables(sys::AbstractSystem)
174182
end
175183
end
176184

177-
function SymbolicIndexingInterface.independent_variables(sys::AbstractTimeDependentSystem)
178-
[getfield(sys, :iv)]
185+
#Treat the result as a vector of symbols always
186+
function SymbolicIndexingInterface.is_variable(sys::AbstractSystem, sym)
187+
if unwrap(sym) isa Int # [x, 1] coerces 1 to a Num
188+
return unwrap(sym) in 1:length(unknown_states(sys))
189+
end
190+
return any(isequal(sym), unknown_states(sys)) || hasname(sym) && is_variable(sys, getname(sym))
191+
end
192+
193+
function SymbolicIndexingInterface.is_variable(sys::AbstractSystem, sym::Symbol)
194+
return any(isequal(sym), getname.(unknown_states(sys))) || count('', string(sym)) == 1 && count(isequal(sym), Symbol.(sys.name, :₊, getname.(unknown_states(sys)))) == 1
179195
end
180-
SymbolicIndexingInterface.independent_variables(sys::AbstractTimeIndependentSystem) = []
181-
function SymbolicIndexingInterface.independent_variables(sys::AbstractMultivariateSystem)
182-
getfield(sys, :ivs)
196+
197+
function SymbolicIndexingInterface.variable_index(sys::AbstractSystem, sym)
198+
if unwrap(sym) isa Int
199+
return unwrap(sym)
200+
end
201+
idx = findfirst(isequal(sym), unknown_states(sys))
202+
if idx === nothing && hasname(sym)
203+
idx = variable_index(sys, getname(sym))
204+
end
205+
return idx
183206
end
184207

208+
function SymbolicIndexingInterface.variable_index(sys::AbstractSystem, sym::Symbol)
209+
idx = findfirst(isequal(sym), getname.(unknown_states(sys)))
210+
if idx !== nothing
211+
return idx
212+
elseif count('', string(sym)) == 1
213+
return findfirst(isequal(sym), Symbol.(sys.name, :₊, getname.(unknown_states(sys))))
214+
end
215+
return nothing
216+
end
217+
218+
function SymbolicIndexingInterface.variable_symbols(sys::AbstractSystem)
219+
return unknown_states(sys)
220+
end
221+
222+
function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym)
223+
if unwrap(sym) isa Int
224+
return unwrap(sym) in 1:length(parameters(sys))
225+
end
226+
227+
return any(isequal(sym), parameters(sys)) || hasname(sym) && is_parameter(sys, getname(sym))
228+
end
229+
230+
function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym::Symbol)
231+
return any(isequal(sym), getname.(parameters(sys))) ||
232+
count('', string(sym)) == 1 && count(isequal(sym), Symbol.(sys.name, :₊, getname.(parameters(sys)))) == 1
233+
end
234+
235+
function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym)
236+
if unwrap(sym) isa Int
237+
return unwrap(sym)
238+
end
239+
idx = findfirst(isequal(sym), parameters(sys))
240+
if idx === nothing && hasname(sym)
241+
idx = parameter_index(sys, getname(sym))
242+
end
243+
return idx
244+
end
245+
246+
function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym::Symbol)
247+
idx = findfirst(isequal(sym), getname.(parameters(sys)))
248+
if idx !== nothing
249+
return idx
250+
elseif count('', string(sym)) == 1
251+
return findfirst(isequal(sym), Symbol.(sys.name, :₊, getname.(parameters(sys))))
252+
end
253+
return nothing
254+
end
255+
256+
function SymbolicIndexingInterface.parameter_symbols(sys::AbstractSystem)
257+
return parameters(sys)
258+
end
259+
260+
function SymbolicIndexingInterface.is_independent_variable(sys::AbstractSystem, sym)
261+
return any(isequal(sym), independent_variables(sys))
262+
end
263+
264+
function SymbolicIndexingInterface.is_independent_variable(sys::AbstractSystem, sym::Symbol)
265+
return any(isequal(sym), getname.(independent_variables(sys)))
266+
end
267+
268+
function SymbolicIndexingInterface.independent_variable_symbols(sys::AbstractSystem)
269+
return independent_variables(sys)
270+
end
271+
272+
function SymbolicIndexingInterface.is_observed(sys::AbstractSystem, sym)
273+
return !is_variable(sys, sym) && !is_parameter(sys, sym) && !is_independent_variable(sys, sym) && symbolic_type(sym) != NotSymbolic()
274+
end
275+
276+
SymbolicIndexingInterface.is_time_dependent(::AbstractTimeDependentSystem) = true
277+
SymbolicIndexingInterface.is_time_dependent(::AbstractTimeIndependentSystem) = false
278+
279+
SymbolicIndexingInterface.constant_structure(::AbstractSystem) = true
280+
185281
iscomplete(sys::AbstractSystem) = isdefined(sys, :complete) && getfield(sys, :complete)
186282

187283
"""
@@ -534,12 +630,15 @@ function states(sys::AbstractSystem)
534630
[sts; reduce(vcat, namespace_variables.(systems))])
535631
end
536632

537-
function SymbolicIndexingInterface.parameters(sys::AbstractSystem)
633+
function parameters(sys::AbstractSystem)
538634
ps = get_ps(sys)
539635
systems = get_systems(sys)
540636
unique(isempty(systems) ? ps : [ps; reduce(vcat, namespace_parameters.(systems))])
541637
end
542638

639+
# required in `src/connectors.jl:437`
640+
parameters(_) = []
641+
543642
function controls(sys::AbstractSystem)
544643
ctrls = get_ctrls(sys)
545644
systems = get_systems(sys)
@@ -638,8 +737,6 @@ function time_varying_as_func(x, sys::AbstractTimeDependentSystem)
638737
return x
639738
end
640739

641-
SymbolicIndexingInterface.is_indep_sym(sys::AbstractSystem, sym) = isequal(sym, get_iv(sys))
642-
643740
"""
644741
$(SIGNATURES)
645742
@@ -653,20 +750,6 @@ function unknown_states(sys::AbstractSystem)
653750
return sts
654751
end
655752

656-
function SymbolicIndexingInterface.state_sym_to_index(sys::AbstractSystem, sym)
657-
findfirst(isequal(sym), unknown_states(sys))
658-
end
659-
function SymbolicIndexingInterface.is_state_sym(sys::AbstractSystem, sym)
660-
!isnothing(SymbolicIndexingInterface.state_sym_to_index(sys, sym))
661-
end
662-
663-
function SymbolicIndexingInterface.param_sym_to_index(sys::AbstractSystem, sym)
664-
findfirst(isequal(sym), SymbolicIndexingInterface.parameters(sys))
665-
end
666-
function SymbolicIndexingInterface.is_param_sym(sys::AbstractSystem, sym)
667-
!isnothing(SymbolicIndexingInterface.param_sym_to_index(sys, sym))
668-
end
669-
670753
###
671754
### System utils
672755
###

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -516,9 +516,9 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, dvs = s
516516
tgrad = _tgrad === nothing ? nothing : _tgrad,
517517
mass_matrix = _M,
518518
jac_prototype = jac_prototype,
519-
syms = Symbol.(states(sys)),
519+
syms = collect(Symbol.(states(sys))),
520520
indepsym = Symbol(get_iv(sys)),
521-
paramsyms = Symbol.(ps),
521+
paramsyms = collect(Symbol.(ps)),
522522
observed = observedfun,
523523
sparsity = sparsity ? jacobian_sparsity(sys) : nothing,
524524
analytic = analytic)

0 commit comments

Comments
 (0)