Skip to content

Commit 518e043

Browse files
Update to use symbolic indexing interface from RecursiveArrayTools
- Don't pass syms to `SciMLFunction`s - Import, reexport and overload interface methods from RecursiveArrayTools - FIXME: Currently unrelated `states`/`parameters` are also overloaded
1 parent 3e8d62a commit 518e043

File tree

11 files changed

+82
-71
lines changed

11 files changed

+82
-71
lines changed

src/ModelingToolkit.jl

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ import FunctionWrappersWrappers
3333
RuntimeGeneratedFunctions.init(@__MODULE__)
3434

3535
using RecursiveArrayTools
36+
export independent_variables, states, parameters
37+
# using RecursiveArrayTools
3638

3739
import SymbolicUtils
3840
import SymbolicUtils: istree, arguments, operation, similarterm, promote_symtype,
@@ -96,28 +98,28 @@ abstract type AbstractODESystem <: AbstractTimeDependentSystem end
9698
abstract type AbstractMultivariateSystem <: AbstractSystem end
9799
abstract type AbstractOptimizationSystem <: AbstractTimeIndependentSystem end
98100

99-
"""
100-
$(TYPEDSIGNATURES)
101+
# """
102+
# $(TYPEDSIGNATURES)
101103

102-
Get the set of independent variables for the given system.
103-
"""
104-
function independent_variables end
104+
# Get the set of independent variables for the given system.
105+
# """
106+
# function independent_variables end
105107

106108
function independent_variable end
107109

108-
"""
109-
$(TYPEDSIGNATURES)
110+
# """
111+
# $(TYPEDSIGNATURES)
110112

111-
Get the set of states for the given system.
112-
"""
113-
function states end
113+
# Get the set of states for the given system.
114+
# """
115+
# function states end
114116

115-
"""
116-
$(TYPEDSIGNATURES)
117+
# """
118+
# $(TYPEDSIGNATURES)
117119

118-
Get the set of parameters variables for the given system.
119-
"""
120-
function parameters end
120+
# Get the set of parameters variables for the given system.
121+
# """
122+
# function parameters end
121123

122124
# this has to be included early to deal with depency issues
123125
include("structural_transformation/bareiss.jl")
@@ -203,7 +205,7 @@ export Differential, expand_derivatives, @derivatives
203205
export Equation, ConstrainedEquation
204206
export Term, Sym
205207
export SymScope, LocalScope, ParentScope, DelayParentScope, GlobalScope
206-
export independent_variables, independent_variable, states, parameters, equations, controls,
208+
export independent_variable, equations, controls,
207209
observed, structure, full_equations
208210
export structural_simplify, expand_connections, linearize, linearization_function
209211
export DiscreteSystem, DiscreteProblem

src/structural_transformation/codegen.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -353,9 +353,9 @@ function build_torn_function(sys;
353353
eqs_idxs,
354354
states_idxs) :
355355
nothing,
356-
syms = syms,
357-
paramsyms = Symbol.(parameters(sys)),
358-
indepsym = Symbol(get_iv(sys)),
356+
# syms = syms,
357+
# paramsyms = Symbol.(parameters(sys)),
358+
# indepsym = Symbol(get_iv(sys)),
359359
observed = observedfun,
360360
mass_matrix = mass_matrix,
361361
sys = sys), states

src/systems/abstractsystem.jl

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ function independent_variable(sys::AbstractSystem)
148148
end
149149

150150
#Treat the result as a vector of symbols always
151-
function independent_variables(sys::AbstractSystem)
151+
function RecursiveArrayTools.independent_variables(sys::AbstractSystem)
152152
systype = typeof(sys)
153153
@warn "Please declare ($systype) as a subtype of `AbstractTimeDependentSystem`, `AbstractTimeIndependentSystem` or `AbstractMultivariateSystem`."
154154
if isdefined(sys, :iv)
@@ -160,9 +160,9 @@ function independent_variables(sys::AbstractSystem)
160160
end
161161
end
162162

163-
independent_variables(sys::AbstractTimeDependentSystem) = [getfield(sys, :iv)]
164-
independent_variables(sys::AbstractTimeIndependentSystem) = []
165-
independent_variables(sys::AbstractMultivariateSystem) = getfield(sys, :ivs)
163+
RecursiveArrayTools.independent_variables(sys::AbstractTimeDependentSystem) = [getfield(sys, :iv)]
164+
RecursiveArrayTools.independent_variables(sys::AbstractTimeIndependentSystem) = []
165+
RecursiveArrayTools.independent_variables(sys::AbstractMultivariateSystem) = getfield(sys, :ivs)
166166

167167
iscomplete(sys::AbstractSystem) = isdefined(sys, :complete) && getfield(sys, :complete)
168168

@@ -462,15 +462,15 @@ function namespace_expr(O, sys, n = nameof(sys))
462462
end
463463
end
464464

465-
function states(sys::AbstractSystem)
465+
function RecursiveArrayTools.states(sys::AbstractSystem)
466466
sts = get_states(sys)
467467
systems = get_systems(sys)
468468
unique(isempty(systems) ?
469469
sts :
470470
[sts; reduce(vcat, namespace_variables.(systems))])
471471
end
472472

473-
function parameters(sys::AbstractSystem)
473+
function RecursiveArrayTools.parameters(sys::AbstractSystem)
474474
ps = get_ps(sys)
475475
systems = get_systems(sys)
476476
unique(isempty(systems) ? ps : [ps; reduce(vcat, namespace_parameters.(systems))])
@@ -505,10 +505,10 @@ function defaults(sys::AbstractSystem)
505505
isempty(systems) ? defs : mapfoldr(namespace_defaults, merge, systems; init = defs)
506506
end
507507

508-
states(sys::AbstractSystem, v) = renamespace(sys, v)
509-
parameters(sys::AbstractSystem, v) = toparam(states(sys, v))
508+
RecursiveArrayTools.states(sys::AbstractSystem, v) = renamespace(sys, v)
509+
RecursiveArrayTools.parameters(sys::AbstractSystem, v) = toparam(states(sys, v))
510510
for f in [:states, :parameters]
511-
@eval $f(sys::AbstractSystem, vs::AbstractArray) = map(v -> $f(sys, v), vs)
511+
@eval RecursiveArrayTools.$f(sys::AbstractSystem, vs::AbstractArray) = map(v -> $f(sys, v), vs)
512512
end
513513

514514
flatten(sys::AbstractSystem, args...) = sys
@@ -572,11 +572,13 @@ function time_varying_as_func(x, sys::AbstractTimeDependentSystem)
572572
return x
573573
end
574574

575-
is_state_sym(sys::AbstractSystem, sym) = sym in states(sys)
576-
state_sym_to_index(sys::AbstractSystem, sym) = findfirst(isequal(sym), states(sys))
575+
RecursiveArrayTools.is_indep_sym(sys::AbstractSystem, sym) = isequal(sym, get_iv(sys))
577576

578-
is_param_sym(sys::AbstractSystem, sym) = sym in parameters(sys)
579-
param_sym_to_index(sys::AbstractSystem, sym) = findfirst(isequal(sym), parameters(sys))
577+
RecursiveArrayTools.state_sym_to_index(sys::AbstractSystem, sym) = findfirst(isequal(sym), states(sys))
578+
RecursiveArrayTools.is_state_sym(sys::AbstractSystem, sym) = !isnothing(RecursiveArrayTools.state_sym_to_index(sys, sym))
579+
580+
RecursiveArrayTools.param_sym_to_index(sys::AbstractSystem, sym) = findfirst(isequal(sym), parameters(sys))
581+
RecursiveArrayTools.is_param_sym(sys::AbstractSystem, sym) = !isnothing(RecursiveArrayTools.param_sym_to_index(sys, sym))
580582

581583
struct AbstractSysToExpr
582584
sys::AbstractSystem

src/systems/callbacks.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,9 @@ FunctionalAffect(; f, sts, pars, ctx = nothing) = FunctionalAffect(f, sts, pars,
3535

3636
func(f::FunctionalAffect) = f.f
3737
context(a::FunctionalAffect) = a.ctx
38-
parameters(a::FunctionalAffect) = a.pars
38+
RecursiveArrayTools.parameters(a::FunctionalAffect) = a.pars
3939
parameters_syms(a::FunctionalAffect) = a.pars_syms
40-
states(a::FunctionalAffect) = a.sts
40+
RecursiveArrayTools.states(a::FunctionalAffect) = a.sts
4141
states_syms(a::FunctionalAffect) = a.sts_syms
4242

4343
function Base.:(==)(a1::FunctionalAffect, a2::FunctionalAffect)

src/systems/connectors.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ function Base.:(==)(l1::ConnectionElement, l2::ConnectionElement)
161161
nameof(l1.sys) == nameof(l2.sys) && isequal(l1.v, l2.v) && l1.isouter == l2.isouter
162162
end
163163
namespaced_var(l::ConnectionElement) = states(l, l.v)
164-
states(l::ConnectionElement, v) = states(copy(l.sys), v)
164+
RecursiveArrayTools.states(l::ConnectionElement, v) = states(copy(l.sys), v)
165165

166166
struct ConnectionSet
167167
set::Vector{ConnectionElement} # namespace.sys, var, isouter

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -380,9 +380,9 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, dvs = s
380380
tgrad = _tgrad === nothing ? nothing : _tgrad,
381381
mass_matrix = _M,
382382
jac_prototype = jac_prototype,
383-
syms = Symbol.(states(sys)),
384-
indepsym = Symbol(get_iv(sys)),
385-
paramsyms = Symbol.(ps),
383+
# syms = Symbol.(states(sys)),
384+
# indepsym = Symbol(get_iv(sys)),
385+
# paramsyms = Symbol.(ps),
386386
observed = observedfun,
387387
sparsity = sparsity ? jacobian_sparsity(sys) : nothing)
388388
end
@@ -473,9 +473,9 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = states(sys),
473473
DAEFunction{iip}(f,
474474
sys = sys,
475475
jac = _jac === nothing ? nothing : _jac,
476-
syms = Symbol.(dvs),
477-
indepsym = Symbol(get_iv(sys)),
478-
paramsyms = Symbol.(ps),
476+
# syms = Symbol.(dvs),
477+
# indepsym = Symbol(get_iv(sys)),
478+
# paramsyms = Symbol.(ps),
479479
jac_prototype = jac_prototype,
480480
observed = observedfun)
481481
end
@@ -559,9 +559,9 @@ function ODEFunctionExpr{iip}(sys::AbstractODESystem, dvs = states(sys),
559559
tgrad = $tgradsym,
560560
mass_matrix = M,
561561
jac_prototype = $jp_expr,
562-
syms = $(Symbol.(states(sys))),
563-
indepsym = $(QuoteNode(Symbol(get_iv(sys)))),
564-
paramsyms = $(Symbol.(parameters(sys))),
562+
# syms = $(Symbol.(states(sys))),
563+
# indepsym = $(QuoteNode(Symbol(get_iv(sys)))),
564+
# paramsyms = $(Symbol.(parameters(sys))),
565565
sparsity = $(sparsity ? jacobian_sparsity(sys) : nothing))
566566
end
567567
!linenumbers ? striplines(ex) : ex

src/systems/diffeqs/sdesystem.jl

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -437,9 +437,9 @@ function DiffEqBase.SDEFunction{iip}(sys::SDESystem, dvs = states(sys),
437437
Wfact = _Wfact === nothing ? nothing : _Wfact,
438438
Wfact_t = _Wfact_t === nothing ? nothing : _Wfact_t,
439439
mass_matrix = _M,
440-
syms = Symbol.(states(sys)),
441-
indepsym = Symbol(get_iv(sys)),
442-
paramsyms = Symbol.(ps),
440+
# syms = Symbol.(states(sys)),
441+
# indepsym = Symbol(get_iv(sys)),
442+
# paramsyms = Symbol.(ps),
443443
observed = observedfun)
444444
end
445445

@@ -524,9 +524,10 @@ function SDEFunctionExpr{iip}(sys::SDESystem, dvs = states(sys),
524524
Wfact = Wfact,
525525
Wfact_t = Wfact_t,
526526
mass_matrix = M,
527-
syms = $(Symbol.(states(sys))),
528-
indepsym = $(Symbol(get_iv(sys))),
529-
paramsyms = $(Symbol.(parameters(sys))))
527+
# syms = $(Symbol.(states(sys))),
528+
# indepsym = $(Symbol(get_iv(sys))),
529+
# paramsyms = $(Symbol.(parameters(sys))))
530+
)
530531
end
531532
!linenumbers ? striplines(ex) : ex
532533
end

src/systems/discrete_system/discrete_system.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -223,8 +223,10 @@ function SciMLBase.DiscreteProblem(sys::DiscreteSystem, u0map = [], tspan = get_
223223
expression_module = eval_module)
224224
f_oop, _ = (@RuntimeGeneratedFunction(eval_module, ex) for ex in f_gen)
225225
f(u, p, iv) = f_oop(u, p, iv)
226-
fd = DiscreteFunction(f; syms = Symbol.(dvs), indepsym = Symbol(iv),
227-
paramsyms = Symbol.(ps), sys = sys)
226+
fd = DiscreteFunction(f; # syms = Symbol.(dvs),
227+
# indepsym = Symbol(iv),
228+
# paramsyms = Symbol.(ps),
229+
sys = sys)
228230
DiscreteProblem(fd, u0, tspan, p; kwargs...)
229231
end
230232

src/systems/jumps/jumpsystem.jl

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -312,9 +312,10 @@ function DiffEqBase.DiscreteProblem(sys::JumpSystem, u0map, tspan::Union{Tuple,
312312
end
313313
end
314314

315-
df = DiscreteFunction{true, true}(f; syms = Symbol.(states(sys)),
316-
indepsym = Symbol(get_iv(sys)),
317-
paramsyms = Symbol.(ps), sys = sys,
315+
df = DiscreteFunction{true, true}(f;
316+
# syms = Symbol.(states(sys)),
317+
# indepsym = Symbol(get_iv(sys)),
318+
# paramsyms = Symbol.(ps), sys = sys,
318319
observed = observedfun)
319320
DiscreteProblem(df, u0, tspan, p; kwargs...)
320321
end
@@ -354,9 +355,11 @@ function DiscreteProblemExpr(sys::JumpSystem, u0map, tspan::Union{Tuple, Nothing
354355
u0 = $u0
355356
p = $p
356357
tspan = $tspan
357-
df = DiscreteFunction{true, true}(f, syms = $(Symbol.(states(sys))),
358-
indepsym = $(Symbol(get_iv(sys))),
359-
paramsyms = $(Symbol.(parameters(sys))))
358+
df = DiscreteFunction{true, true}(f,
359+
# syms = $(Symbol.(states(sys))),
360+
# indepsym = $(Symbol(get_iv(sys))),
361+
# paramsyms = $(Symbol.(parameters(sys)))
362+
)
360363
DiscreteProblem(df, u0, tspan, p)
361364
end
362365
end

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -253,8 +253,8 @@ function SciMLBase.NonlinearFunction{iip}(sys::NonlinearSystem, dvs = states(sys
253253
jac_prototype = sparse ?
254254
similar(calculate_jacobian(sys, sparse = sparse),
255255
Float64) : nothing,
256-
syms = Symbol.(states(sys)),
257-
paramsyms = Symbol.(parameters(sys)),
256+
# syms = Symbol.(states(sys)),
257+
# paramsyms = Symbol.(parameters(sys)),
258258
observed = observedfun)
259259
end
260260

@@ -300,8 +300,9 @@ function NonlinearFunctionExpr{iip}(sys::NonlinearSystem, dvs = states(sys),
300300
NonlinearFunction{$iip}(f,
301301
jac = jac,
302302
jac_prototype = $jp_expr,
303-
syms = $(Symbol.(states(sys))),
304-
paramsyms = $(Symbol.(parameters(sys))))
303+
# syms = $(Symbol.(states(sys))),
304+
# paramsyms = $(Symbol.(parameters(sys)))
305+
)
305306
end
306307
!linenumbers ? striplines(ex) : ex
307308
end
@@ -330,7 +331,7 @@ function process_NonlinearProblem(constructor, sys::NonlinearSystem, u0map, para
330331

331332
f = constructor(sys, dvs, ps, u0; jac = jac, checkbounds = checkbounds,
332333
linenumbers = linenumbers, parallel = parallel, simplify = simplify,
333-
syms = Symbol.(dvs), paramsyms = Symbol.(ps),
334+
# syms = Symbol.(dvs), paramsyms = Symbol.(ps),
334335
sparse = sparse, eval_expression = eval_expression, kwargs...)
335336
return f, u0, p
336337
end

0 commit comments

Comments
 (0)