Skip to content

Commit e1d27d7

Browse files
Merge pull request #1988 from AayushSabharwal/syms_rework
Symbolic indexing rework
2 parents 979fc24 + 269519f commit e1d27d7

File tree

5 files changed

+39
-33
lines changed

5 files changed

+39
-33
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7"
4040
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
4141
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
4242
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
43+
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
4344
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
4445
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
4546
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
@@ -76,6 +77,7 @@ Setfield = "0.7, 0.8, 1"
7677
SimpleNonlinearSolve = "0.1.0"
7778
SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2"
7879
StaticArrays = "0.10, 0.11, 0.12, 1.0"
80+
SymbolicIndexingInterface = "0.1"
7981
SymbolicUtils = "0.19"
8082
Symbolics = "4.9"
8183
UnPack = "0.1, 1.0"

src/ModelingToolkit.jl

Lines changed: 4 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ RuntimeGeneratedFunctions.init(@__MODULE__)
3434

3535
using RecursiveArrayTools
3636

37+
import SymbolicIndexingInterface
38+
import SymbolicIndexingInterface: independent_variables, states, parameters
39+
export independent_variables, states, parameters
3740
import SymbolicUtils
3841
import SymbolicUtils: istree, arguments, operation, similarterm, promote_symtype,
3942
Symbolic, Term, Add, Mul, Pow, Sym, FnType,
@@ -96,29 +99,8 @@ abstract type AbstractODESystem <: AbstractTimeDependentSystem end
9699
abstract type AbstractMultivariateSystem <: AbstractSystem end
97100
abstract type AbstractOptimizationSystem <: AbstractTimeIndependentSystem end
98101

99-
"""
100-
$(TYPEDSIGNATURES)
101-
102-
Get the set of independent variables for the given system.
103-
"""
104-
function independent_variables end
105-
106102
function independent_variable end
107103

108-
"""
109-
$(TYPEDSIGNATURES)
110-
111-
Get the set of states for the given system.
112-
"""
113-
function states end
114-
115-
"""
116-
$(TYPEDSIGNATURES)
117-
118-
Get the set of parameters variables for the given system.
119-
"""
120-
function parameters end
121-
122104
# this has to be included early to deal with depency issues
123105
include("structural_transformation/bareiss.jl")
124106
function complete end
@@ -203,7 +185,7 @@ export Differential, expand_derivatives, @derivatives
203185
export Equation, ConstrainedEquation
204186
export Term, Sym
205187
export SymScope, LocalScope, ParentScope, DelayParentScope, GlobalScope
206-
export independent_variables, independent_variable, states, parameters, equations, controls,
188+
export independent_variable, equations, controls,
207189
observed, structure, full_equations
208190
export structural_simplify, expand_connections, linearize, linearization_function
209191
export DiscreteSystem, DiscreteProblem

src/systems/abstractsystem.jl

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ function independent_variable(sys::AbstractSystem)
148148
end
149149

150150
#Treat the result as a vector of symbols always
151-
function independent_variables(sys::AbstractSystem)
151+
function SymbolicIndexingInterface.independent_variables(sys::AbstractSystem)
152152
systype = typeof(sys)
153153
@warn "Please declare ($systype) as a subtype of `AbstractTimeDependentSystem`, `AbstractTimeIndependentSystem` or `AbstractMultivariateSystem`."
154154
if isdefined(sys, :iv)
@@ -160,9 +160,13 @@ function independent_variables(sys::AbstractSystem)
160160
end
161161
end
162162

163-
independent_variables(sys::AbstractTimeDependentSystem) = [getfield(sys, :iv)]
164-
independent_variables(sys::AbstractTimeIndependentSystem) = []
165-
independent_variables(sys::AbstractMultivariateSystem) = getfield(sys, :ivs)
163+
function SymbolicIndexingInterface.independent_variables(sys::AbstractTimeDependentSystem)
164+
[getfield(sys, :iv)]
165+
end
166+
SymbolicIndexingInterface.independent_variables(sys::AbstractTimeIndependentSystem) = []
167+
function SymbolicIndexingInterface.independent_variables(sys::AbstractMultivariateSystem)
168+
getfield(sys, :ivs)
169+
end
166170

167171
iscomplete(sys::AbstractSystem) = isdefined(sys, :complete) && getfield(sys, :complete)
168172

@@ -462,15 +466,15 @@ function namespace_expr(O, sys, n = nameof(sys))
462466
end
463467
end
464468

465-
function states(sys::AbstractSystem)
469+
function SymbolicIndexingInterface.states(sys::AbstractSystem)
466470
sts = get_states(sys)
467471
systems = get_systems(sys)
468472
unique(isempty(systems) ?
469473
sts :
470474
[sts; reduce(vcat, namespace_variables.(systems))])
471475
end
472476

473-
function parameters(sys::AbstractSystem)
477+
function SymbolicIndexingInterface.parameters(sys::AbstractSystem)
474478
ps = get_ps(sys)
475479
systems = get_systems(sys)
476480
unique(isempty(systems) ? ps : [ps; reduce(vcat, namespace_parameters.(systems))])
@@ -508,7 +512,9 @@ end
508512
states(sys::AbstractSystem, v) = renamespace(sys, v)
509513
parameters(sys::AbstractSystem, v) = toparam(states(sys, v))
510514
for f in [:states, :parameters]
511-
@eval $f(sys::AbstractSystem, vs::AbstractArray) = map(v -> $f(sys, v), vs)
515+
@eval function $f(sys::AbstractSystem, vs::AbstractArray)
516+
map(v -> $f(sys, v), vs)
517+
end
512518
end
513519

514520
flatten(sys::AbstractSystem, args...) = sys
@@ -572,6 +578,22 @@ function time_varying_as_func(x, sys::AbstractTimeDependentSystem)
572578
return x
573579
end
574580

581+
SymbolicIndexingInterface.is_indep_sym(sys::AbstractSystem, sym) = isequal(sym, get_iv(sys))
582+
583+
function SymbolicIndexingInterface.state_sym_to_index(sys::AbstractSystem, sym)
584+
findfirst(isequal(sym), SymbolicIndexingInterface.states(sys))
585+
end
586+
function SymbolicIndexingInterface.is_state_sym(sys::AbstractSystem, sym)
587+
!isnothing(SymbolicIndexingInterface.state_sym_to_index(sys, sym))
588+
end
589+
590+
function SymbolicIndexingInterface.param_sym_to_index(sys::AbstractSystem, sym)
591+
findfirst(isequal(sym), SymbolicIndexingInterface.parameters(sys))
592+
end
593+
function SymbolicIndexingInterface.is_param_sym(sys::AbstractSystem, sym)
594+
!isnothing(SymbolicIndexingInterface.param_sym_to_index(sys, sym))
595+
end
596+
575597
struct AbstractSysToExpr
576598
sys::AbstractSystem
577599
states::Vector

src/systems/jumps/jumpsystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -354,7 +354,7 @@ function DiscreteProblemExpr(sys::JumpSystem, u0map, tspan::Union{Tuple, Nothing
354354
u0 = $u0
355355
p = $p
356356
tspan = $tspan
357-
df = DiscreteFunction{true, true}(f, syms = $(Symbol.(states(sys))),
357+
df = DiscreteFunction{true, true}(f; syms = $(Symbol.(states(sys))),
358358
indepsym = $(Symbol(get_iv(sys))),
359359
paramsyms = $(Symbol.(parameters(sys))))
360360
DiscreteProblem(df, u0, tspan, p)

test/odesystem.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -930,9 +930,9 @@ let
930930
# TODO: maybe do not emit x_t
931931
sys4s = structural_simplify(sys4)
932932
prob = ODAEProblem(sys4s, [x => 1.0, D(x) => 1.0], (0, 1.0))
933-
@test string.(prob.f.syms) == ["x(t)", "xˍt(t)"]
934-
@test string.(prob.f.paramsyms) == ["pp"]
935-
@test string(prob.f.indepsym) == "t"
933+
@test string.(states(prob.f.sys)) == ["x(t)", "xˍt(t)"]
934+
@test string.(parameters(prob.f.sys)) == ["pp"]
935+
@test string.(independent_variables(prob.f.sys)) == ["t"]
936936
end
937937

938938
let

0 commit comments

Comments
 (0)