Skip to content

Commit 93cab23

Browse files
xtalaxYingboMa
andauthored
[Awaiting Review] Add metadata field to systems (#1768)
Co-authored-by: Yingbo Ma <[email protected]>
1 parent 87f0c15 commit 93cab23

File tree

9 files changed

+69
-26
lines changed

9 files changed

+69
-26
lines changed

docs/src/basics/AbstractSystem.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ Optionally, a system could have:
6060
- `get_defaults(sys)`: A `Dict` that maps variables into their default values.
6161
- `independent_variables(sys)`: The independent variables of a system.
6262
- `get_noiseeqs(sys)`: Noise equations of the current-level system.
63+
- `get_metadata(sys)`: Any metadata about the system or its origin to be used by downstream packages.
6364

6465
Note that if you know a system is an `AbstractTimeDependentSystem` you could use `get_iv` to get the
6566
unique independent variable directly, rather than using `independent_variables(sys)[1]`, which is clunky and may cause problems if `sys` is an `AbstractMultivariateSystem` because there may be more than one independent variable. `AbstractTimeIndependentSystem`s do not have a method `get_iv`, and `independent_variables(sys)` will return a size-zero result for such. For an `AbstractMultivariateSystem`, `get_ivs` is equivalent.

src/systems/abstractsystem.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,8 @@ for prop in [:eqs
192192
:preface
193193
:torn_matching
194194
:tearing_state
195-
:substitutions]
195+
:substitutions
196+
:metadata]
196197
fname1 = Symbol(:get_, prop)
197198
fname2 = Symbol(:has_, prop)
198199
@eval begin

src/systems/diffeqs/odesystem.jl

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -115,11 +115,16 @@ struct ODESystem <: AbstractODESystem
115115
substitutions: substitutions generated by tearing.
116116
"""
117117
substitutions::Any
118+
"""
119+
metadata: metadata for the system, to be used by downstream packages.
120+
"""
121+
metadata::Any
118122

119123
function ODESystem(deqs, iv, dvs, ps, var_to_name, ctrls, observed, tgrad,
120124
jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults,
121125
torn_matching, connector_type, connections, preface, cevents,
122-
devents, tearing_state = nothing, substitutions = nothing;
126+
devents, tearing_state = nothing, substitutions = nothing,
127+
metadata = nothing;
123128
checks::Union{Bool, Int} = true)
124129
if checks == true || (checks & CheckComponents) > 0
125130
check_variables(dvs, iv)
@@ -133,7 +138,7 @@ struct ODESystem <: AbstractODESystem
133138
new(deqs, iv, dvs, ps, var_to_name, ctrls, observed, tgrad, jac,
134139
ctrl_jac, Wfact, Wfact_t, name, systems, defaults, torn_matching,
135140
connector_type, connections, preface, cevents, devents, tearing_state,
136-
substitutions)
141+
substitutions, metadata)
137142
end
138143
end
139144

@@ -149,7 +154,8 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
149154
preface = nothing,
150155
continuous_events = nothing,
151156
discrete_events = nothing,
152-
checks = true)
157+
checks = true,
158+
metadata = nothing)
153159
name === nothing &&
154160
throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro"))
155161
deqs = scalarize(deqs)
@@ -186,7 +192,7 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
186192
ODESystem(deqs, iv′, dvs′, ps′, var_to_name, ctrl′, observed, tgrad, jac,
187193
ctrl_jac, Wfact, Wfact_t, name, systems, defaults, nothing,
188194
connector_type, nothing, preface, cont_callbacks, disc_callbacks,
189-
checks = checks)
195+
metadata, checks = checks)
190196
end
191197

192198
function ODESystem(eqs, iv = nothing; kwargs...)

src/systems/diffeqs/sdesystem.jl

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -96,10 +96,14 @@ struct SDESystem <: AbstractODESystem
9696
true at the end of an integration step.
9797
"""
9898
discrete_events::Vector{SymbolicDiscreteCallback}
99-
99+
"""
100+
metadata: metadata for the system, to be used by downstream packages.
101+
"""
102+
metadata::Any
100103
function SDESystem(deqs, neqs, iv, dvs, ps, var_to_name, ctrls, observed, tgrad, jac,
101104
ctrl_jac, Wfact, Wfact_t, name, systems, defaults, connector_type,
102-
cevents, devents; checks::Union{Bool, Int} = true)
105+
cevents, devents, metadata = nothing;
106+
checks::Union{Bool, Int} = true)
103107
if checks == true || (checks & CheckComponents) > 0
104108
check_variables(dvs, iv)
105109
check_parameters(ps, iv)
@@ -110,7 +114,8 @@ struct SDESystem <: AbstractODESystem
110114
all_dimensionless([dvs; ps; iv]) || check_units(deqs, neqs)
111115
end
112116
new(deqs, neqs, iv, dvs, ps, var_to_name, ctrls, observed, tgrad, jac, ctrl_jac,
113-
Wfact, Wfact_t, name, systems, defaults, connector_type, cevents, devents)
117+
Wfact, Wfact_t, name, systems, defaults, connector_type, cevents, devents,
118+
metadata)
114119
end
115120
end
116121

@@ -125,7 +130,8 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs, iv, dvs, ps;
125130
connector_type = nothing,
126131
checks = true,
127132
continuous_events = nothing,
128-
discrete_events = nothing)
133+
discrete_events = nothing,
134+
metadata = nothing)
129135
name === nothing &&
130136
throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro"))
131137
deqs = scalarize(deqs)
@@ -160,7 +166,7 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs, iv, dvs, ps;
160166

161167
SDESystem(deqs, neqs, iv′, dvs′, ps′, var_to_name, ctrl′, observed, tgrad, jac,
162168
ctrl_jac, Wfact, Wfact_t, name, systems, defaults, connector_type,
163-
cont_callbacks, disc_callbacks; checks = checks)
169+
cont_callbacks, disc_callbacks, metadata; checks = checks)
164170
end
165171

166172
function SDESystem(sys::ODESystem, neqs; kwargs...)

src/systems/discrete_system/discrete_system.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,15 @@ struct DiscreteSystem <: AbstractTimeDependentSystem
6767
substitutions: substitutions generated by tearing.
6868
"""
6969
substitutions::Any
70+
"""
71+
metadata: metadata for the system, to be used by downstream packages.
72+
"""
73+
metadata::Any
7074

7175
function DiscreteSystem(discreteEqs, iv, dvs, ps, var_to_name, ctrls, observed, name,
7276
systems, defaults, preface, connector_type,
73-
tearing_state = nothing, substitutions = nothing;
77+
tearing_state = nothing, substitutions = nothing,
78+
metadata = nothing;
7479
checks::Union{Bool, Int} = true)
7580
if checks == true || (checks & CheckComponents) > 0
7681
check_variables(dvs, iv)
@@ -80,7 +85,7 @@ struct DiscreteSystem <: AbstractTimeDependentSystem
8085
all_dimensionless([dvs; ps; iv; ctrls]) || check_units(discreteEqs)
8186
end
8287
new(discreteEqs, iv, dvs, ps, var_to_name, ctrls, observed, name, systems, defaults,
83-
preface, connector_type, tearing_state, substitutions)
88+
preface, connector_type, tearing_state, substitutions, metadata)
8489
end
8590
end
8691

@@ -99,6 +104,7 @@ function DiscreteSystem(eqs::AbstractVector{<:Equation}, iv, dvs, ps;
99104
defaults = _merge(Dict(default_u0), Dict(default_p)),
100105
preface = nothing,
101106
connector_type = nothing,
107+
metadata = nothing,
102108
kwargs...)
103109
name === nothing &&
104110
throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro"))
@@ -125,7 +131,7 @@ function DiscreteSystem(eqs::AbstractVector{<:Equation}, iv, dvs, ps;
125131
throw(ArgumentError("System names must be unique."))
126132
end
127133
DiscreteSystem(eqs, iv′, dvs′, ps′, var_to_name, ctrl′, observed, name, systems,
128-
defaults, preface, connector_type, kwargs...)
134+
defaults, preface, connector_type, metadata, kwargs...)
129135
end
130136

131137
function DiscreteSystem(eqs, iv = nothing; kwargs...)

src/systems/jumps/jumpsystem.jl

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,9 +83,13 @@ struct JumpSystem{U <: ArrayPartition} <: AbstractTimeDependentSystem
8383
state value or parameter.*
8484
"""
8585
discrete_events::Vector{SymbolicDiscreteCallback}
86-
86+
"""
87+
metadata: metadata for the system, to be used by downstream packages.
88+
"""
89+
metadata::Any
8790
function JumpSystem{U}(ap::U, iv, states, ps, var_to_name, observed, name, systems,
88-
defaults, connector_type, devents;
91+
defaults, connector_type, devents,
92+
metadata = nothing;
8993
checks::Union{Bool, Int} = true) where {U <: ArrayPartition}
9094
if checks == true || (checks & CheckComponents) > 0
9195
check_variables(states, iv)
@@ -95,7 +99,7 @@ struct JumpSystem{U <: ArrayPartition} <: AbstractTimeDependentSystem
9599
all_dimensionless([states; ps; iv]) || check_units(ap, iv)
96100
end
97101
new{U}(ap, iv, states, ps, var_to_name, observed, name, systems, defaults,
98-
connector_type, devents)
102+
connector_type, devents, metadata)
99103
end
100104
end
101105

@@ -110,6 +114,7 @@ function JumpSystem(eqs, iv, states, ps;
110114
checks = true,
111115
continuous_events = nothing,
112116
discrete_events = nothing,
117+
metadata = nothing,
113118
kwargs...)
114119
name === nothing &&
115120
throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro"))
@@ -147,7 +152,8 @@ function JumpSystem(eqs, iv, states, ps;
147152
disc_callbacks = SymbolicDiscreteCallbacks(discrete_events)
148153

149154
JumpSystem{typeof(ap)}(ap, value(iv), states, ps, var_to_name, observed, name, systems,
150-
defaults, connector_type, disc_callbacks; checks = checks)
155+
defaults, connector_type, disc_callbacks, metadata,
156+
checks = checks)
151157
end
152158

153159
function generate_rate_function(js::JumpSystem, rate)

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,15 +62,20 @@ struct NonlinearSystem <: AbstractTimeIndependentSystem
6262
substitutions: substitutions generated by tearing.
6363
"""
6464
substitutions::Any
65+
"""
66+
metadata: metadata for the system, to be used by downstream packages.
67+
"""
68+
metadata::Any
6569

6670
function NonlinearSystem(eqs, states, ps, var_to_name, observed, jac, name, systems,
6771
defaults, connector_type, connections, tearing_state = nothing,
68-
substitutions = nothing; checks::Union{Bool, Int} = true)
72+
substitutions = nothing, metadata = nothing;
73+
checks::Union{Bool, Int} = true)
6974
if checks == true || (checks & CheckUnits) > 0
7075
all_dimensionless([states; ps]) || check_units(eqs)
7176
end
7277
new(eqs, states, ps, var_to_name, observed, jac, name, systems, defaults,
73-
connector_type, connections, tearing_state, substitutions)
78+
connector_type, connections, tearing_state, substitutions, metadata)
7479
end
7580
end
7681

@@ -84,7 +89,8 @@ function NonlinearSystem(eqs, states, ps;
8489
connector_type = nothing,
8590
continuous_events = nothing, # this argument is only required for ODESystems, but is added here for the constructor to accept it without error
8691
discrete_events = nothing, # this argument is only required for ODESystems, but is added here for the constructor to accept it without error
87-
checks = true)
92+
checks = true,
93+
metadata = nothing)
8894
continuous_events === nothing || isempty(continuous_events) ||
8995
throw(ArgumentError("NonlinearSystem does not accept `continuous_events`, you provided $continuous_events"))
9096
discrete_events === nothing || isempty(discrete_events) ||
@@ -119,7 +125,7 @@ function NonlinearSystem(eqs, states, ps;
119125
isempty(observed) || collect_var_to_name!(var_to_name, (eq.lhs for eq in observed))
120126

121127
NonlinearSystem(eqs, states, ps, var_to_name, observed, jac, name, systems, defaults,
122-
connector_type, nothing, checks = checks)
128+
connector_type, nothing, metadata, checks = checks)
123129
end
124130

125131
function calculate_jacobian(sys::NonlinearSystem; sparse = false, simplify = false)

src/systems/optimization/optimizationsystem.jl

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,16 +40,20 @@ struct OptimizationSystem <: AbstractTimeIndependentSystem
4040
parameters are not supplied in `ODEProblem`.
4141
"""
4242
defaults::Dict
43+
"""
44+
metadata: metadata for the system, to be used by downstream packages.
45+
"""
46+
metadata::Any
4347
function OptimizationSystem(op, states, ps, var_to_name, observed,
44-
constraints, name, systems, defaults;
48+
constraints, name, systems, defaults, metadata = nothing;
4549
checks::Union{Bool, Int} = true)
4650
if checks == true || (checks & CheckUnits) > 0
4751
check_units(op)
4852
check_units(observed)
4953
all_dimensionless([states; ps]) || check_units(constraints)
5054
end
5155
new(op, states, ps, var_to_name, observed,
52-
constraints, name, systems, defaults)
56+
constraints, name, systems, defaults, metadata)
5357
end
5458
end
5559

@@ -61,7 +65,8 @@ function OptimizationSystem(op, states, ps;
6165
defaults = _merge(Dict(default_u0), Dict(default_p)),
6266
name = nothing,
6367
systems = OptimizationSystem[],
64-
checks = true)
68+
checks = true,
69+
metadata = nothing)
6570
name === nothing &&
6671
throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro"))
6772
if !(isempty(default_u0) && isempty(default_p))
@@ -83,7 +88,7 @@ function OptimizationSystem(op, states, ps;
8388
OptimizationSystem(value(op), states, ps, var_to_name,
8489
observed,
8590
constraints,
86-
name, systems, defaults; checks = checks)
91+
name, systems, defaults, metadata; checks = checks)
8792
end
8893

8994
function calculate_gradient(sys::OptimizationSystem)

src/systems/pde/pdesystem.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,18 +64,24 @@ struct PDESystem <: ModelingToolkit.AbstractMultivariateSystem
6464
name: the name of the system
6565
"""
6666
name::Symbol
67+
"""
68+
metadata: metadata for the system, to be used by downstream packages.
69+
"""
70+
metadata::Any
6771
@add_kwonly function PDESystem(eqs, bcs, domain, ivs, dvs,
6872
ps = SciMLBase.NullParameters();
6973
defaults = Dict(),
7074
systems = [],
7175
connector_type = nothing,
76+
metadata = nothing,
7277
checks::Union{Bool, Int} = true,
7378
name)
7479
if checks == true || (checks & CheckUnits) > 0
7580
all_dimensionless([dvs; ivs; ps]) || check_units(eqs)
7681
end
7782
eqs = eqs isa Vector ? eqs : [eqs]
78-
new(eqs, bcs, domain, ivs, dvs, ps, defaults, connector_type, systems, name)
83+
new(eqs, bcs, domain, ivs, dvs, ps, defaults, connector_type, systems, name,
84+
metadata)
7985
end
8086
end
8187

0 commit comments

Comments
 (0)