Skip to content

Commit fba8b2e

Browse files
authored
Merge pull request #1819 from SciML/myb/md
Reorder metadata and pass kwarg
2 parents 2c0b42b + 6478bec commit fba8b2e

File tree

10 files changed

+71
-35
lines changed

10 files changed

+71
-35
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ NonlinearSolve = "0.3.8"
7070
RecursiveArrayTools = "2.3"
7171
Reexport = "0.2, 1"
7272
RuntimeGeneratedFunctions = "0.4.3, 0.5"
73-
SciMLBase = "1.54"
73+
SciMLBase = "1.56.1"
7474
Setfield = "0.7, 0.8, 1"
7575
SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2"
7676
StaticArrays = "0.10, 0.11, 0.12, 1.0"

src/ModelingToolkit.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ module ModelingToolkit
55
using DocStringExtensions
66
using AbstractTrees
77
using DiffEqBase, SciMLBase, ForwardDiff, Reexport
8+
using SciMLBase: StandardODEProblem, StandardNonlinearProblem
89
using Distributed
910
using StaticArrays, LinearAlgebra, SparseArrays, LabelledArrays
1011
using InteractiveUtils

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -669,10 +669,12 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map, t
669669
check_length, kwargs...)
670670
cbs = process_events(sys; callback, has_difference, kwargs...)
671671
kwargs = filter_kwargs(kwargs)
672+
pt = something(get_metadata(sys), StandardODEProblem())
673+
672674
if cbs === nothing
673-
ODEProblem{iip}(f, u0, tspan, p; kwargs...)
675+
ODEProblem{iip}(f, u0, tspan, p, pt; kwargs...)
674676
else
675-
ODEProblem{iip}(f, u0, tspan, p; callback = cbs, kwargs...)
677+
ODEProblem{iip}(f, u0, tspan, p, pt; callback = cbs, kwargs...)
676678
end
677679
end
678680
get_callback(prob::ODEProblem) = prob.kwargs[:callback]

src/systems/diffeqs/odesystem.jl

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -89,10 +89,6 @@ struct ODESystem <: AbstractODESystem
8989
"""
9090
connector_type::Any
9191
"""
92-
connections: connections in a system
93-
"""
94-
connections::Any
95-
"""
9692
preface: inject assignment statements before the evaluation of the RHS function.
9793
"""
9894
preface::Any
@@ -108,23 +104,23 @@ struct ODESystem <: AbstractODESystem
108104
"""
109105
discrete_events::Vector{SymbolicDiscreteCallback}
110106
"""
107+
metadata: metadata for the system, to be used by downstream packages.
108+
"""
109+
metadata::Any
110+
"""
111111
tearing_state: cache for intermediate tearing state
112112
"""
113113
tearing_state::Any
114114
"""
115115
substitutions: substitutions generated by tearing.
116116
"""
117117
substitutions::Any
118-
"""
119-
metadata: metadata for the system, to be used by downstream packages.
120-
"""
121-
metadata::Any
122118

123119
function ODESystem(deqs, iv, dvs, ps, var_to_name, ctrls, observed, tgrad,
124120
jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults,
125-
torn_matching, connector_type, connections, preface, cevents,
126-
devents, tearing_state = nothing, substitutions = nothing,
127-
metadata = nothing;
121+
torn_matching, connector_type, preface, cevents,
122+
devents, metadata = nothing, tearing_state = nothing,
123+
substitutions = nothing;
128124
checks::Union{Bool, Int} = true)
129125
if checks == true || (checks & CheckComponents) > 0
130126
check_variables(dvs, iv)
@@ -137,8 +133,8 @@ struct ODESystem <: AbstractODESystem
137133
end
138134
new(deqs, iv, dvs, ps, var_to_name, ctrls, observed, tgrad, jac,
139135
ctrl_jac, Wfact, Wfact_t, name, systems, defaults, torn_matching,
140-
connector_type, connections, preface, cevents, devents, tearing_state,
141-
substitutions, metadata)
136+
connector_type, preface, cevents, devents, metadata, tearing_state,
137+
substitutions)
142138
end
143139
end
144140

@@ -191,7 +187,7 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
191187
disc_callbacks = SymbolicDiscreteCallbacks(discrete_events)
192188
ODESystem(deqs, iv′, dvs′, ps′, var_to_name, ctrl′, observed, tgrad, jac,
193189
ctrl_jac, Wfact, Wfact_t, name, systems, defaults, nothing,
194-
connector_type, nothing, preface, cont_callbacks, disc_callbacks,
190+
connector_type, preface, cont_callbacks, disc_callbacks,
195191
metadata, checks = checks)
196192
end
197193

src/systems/discrete_system/discrete_system.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -60,22 +60,22 @@ struct DiscreteSystem <: AbstractTimeDependentSystem
6060
"""
6161
connector_type::Any
6262
"""
63+
metadata: metadata for the system, to be used by downstream packages.
64+
"""
65+
metadata::Any
66+
"""
6367
tearing_state: cache for intermediate tearing state
6468
"""
6569
tearing_state::Any
6670
"""
6771
substitutions: substitutions generated by tearing.
6872
"""
6973
substitutions::Any
70-
"""
71-
metadata: metadata for the system, to be used by downstream packages.
72-
"""
73-
metadata::Any
7474

7575
function DiscreteSystem(discreteEqs, iv, dvs, ps, var_to_name, ctrls, observed, name,
7676
systems, defaults, preface, connector_type,
77-
tearing_state = nothing, substitutions = nothing,
78-
metadata = nothing;
77+
metadata = nothing,
78+
tearing_state = nothing, substitutions = nothing;
7979
checks::Union{Bool, Int} = true)
8080
if checks == true || (checks & CheckComponents) > 0
8181
check_variables(dvs, iv)
@@ -85,7 +85,7 @@ struct DiscreteSystem <: AbstractTimeDependentSystem
8585
all_dimensionless([dvs; ps; iv; ctrls]) || check_units(discreteEqs)
8686
end
8787
new(discreteEqs, iv, dvs, ps, var_to_name, ctrls, observed, name, systems, defaults,
88-
preface, connector_type, tearing_state, substitutions, metadata)
88+
preface, connector_type, metadata, tearing_state, substitutions)
8989
end
9090
end
9191

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,9 @@ struct NonlinearSystem <: AbstractTimeIndependentSystem
5151
"""
5252
connector_type::Any
5353
"""
54-
connections: connections in a system
54+
metadata: metadata for the system, to be used by downstream packages.
5555
"""
56-
connections::Any
56+
metadata::Any
5757
"""
5858
tearing_state: cache for intermediate tearing state
5959
"""
@@ -62,20 +62,16 @@ struct NonlinearSystem <: AbstractTimeIndependentSystem
6262
substitutions: substitutions generated by tearing.
6363
"""
6464
substitutions::Any
65-
"""
66-
metadata: metadata for the system, to be used by downstream packages.
67-
"""
68-
metadata::Any
6965

7066
function NonlinearSystem(eqs, states, ps, var_to_name, observed, jac, name, systems,
71-
defaults, connector_type, connections, tearing_state = nothing,
72-
substitutions = nothing, metadata = nothing;
67+
defaults, connector_type, metadata = nothing,
68+
tearing_state = nothing, substitutions = nothing;
7369
checks::Union{Bool, Int} = true)
7470
if checks == true || (checks & CheckUnits) > 0
7571
all_dimensionless([states; ps]) || check_units(eqs)
7672
end
7773
new(eqs, states, ps, var_to_name, observed, jac, name, systems, defaults,
78-
connector_type, connections, tearing_state, substitutions, metadata)
74+
connector_type, metadata, tearing_state, substitutions)
7975
end
8076
end
8177

@@ -125,7 +121,7 @@ function NonlinearSystem(eqs, states, ps;
125121
isempty(observed) || collect_var_to_name!(var_to_name, (eq.lhs for eq in observed))
126122

127123
NonlinearSystem(eqs, states, ps, var_to_name, observed, jac, name, systems, defaults,
128-
connector_type, nothing, metadata, checks = checks)
124+
connector_type, metadata, checks = checks)
129125
end
130126

131127
function calculate_jacobian(sys::NonlinearSystem; sparse = false, simplify = false)
@@ -344,7 +340,8 @@ function DiffEqBase.NonlinearProblem{iip}(sys::NonlinearSystem, u0map,
344340
check_length = true, kwargs...) where {iip}
345341
f, u0, p = process_NonlinearProblem(NonlinearFunction{iip}, sys, u0map, parammap;
346342
check_length, kwargs...)
347-
NonlinearProblem{iip}(f, u0, p; kwargs...)
343+
pt = something(get_metadata(sys), StandardNonlinearProblem())
344+
NonlinearProblem{iip}(f, u0, p, pt; kwargs...)
348345
end
349346

350347
"""

test/discretesystem.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
- https://en.wikipedia.org/wiki/Compartmental_models_in_epidemiology#Deterministic_versus_stochastic_epidemic_models
55
=#
66
using ModelingToolkit, Test
7+
using ModelingToolkit: get_metadata
78

89
@inline function rate_to_proportion(r, t)
910
1 - exp(-r * t)
@@ -179,3 +180,10 @@ RHS2 = RHS
179180
sol = solve(prob, FunctionMap(); dt = dt)
180181
@test c[1] + 1 == length(sol)
181182
end
183+
184+
@parameters t
185+
@variables x(t) y(t)
186+
D = Difference(t; dt = 0.1)
187+
testdict = Dict([:test => 1])
188+
@named sys = DiscreteSystem([D(x) ~ 1.0]; metadata = testdict)
189+
@test get_metadata(sys) == testdict

test/nonlinearsystem.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using ModelingToolkit, StaticArrays, LinearAlgebra
2+
using ModelingToolkit: get_metadata
23
using DiffEqBase, SparseArrays
34
using Test
45
using NonlinearSolve
@@ -202,3 +203,11 @@ let
202203

203204
@test sol[u] ones(4)
204205
end
206+
207+
@variables x(t)
208+
@parameters a
209+
eqs = [0 ~ a * x]
210+
211+
testdict = Dict([:test => 1])
212+
@named sys = NonlinearSystem(eqs, [x], [a], metadata = testdict)
213+
@test get_metadata(sys) == testdict

test/odesystem.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using ModelingToolkit, StaticArrays, LinearAlgebra
2+
using ModelingToolkit: get_metadata
23
using OrdinaryDiffEq, Sundials
34
using DiffEqBase, SparseArrays
45
using StaticArrays
@@ -878,3 +879,14 @@ let
878879
∂t(P) ~ -80.0sin(Q)]
879880
@test_throws ArgumentError @named sys = ODESystem(eqs)
880881
end
882+
883+
@parameters C L R
884+
@variables t q(t) p(t) F(t)
885+
D = Differential(t)
886+
887+
eqs = [D(q) ~ -p / L - F
888+
D(p) ~ q / C
889+
0 ~ q / C - R * F]
890+
testdict = Dict([:name => "test"])
891+
@named sys = ODESystem(eqs, t, metadata = testdict)
892+
@test get_metadata(sys) == testdict

test/optimizationsystem.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using ModelingToolkit, SparseArrays, Test, Optimization, OptimizationOptimJL,
22
OptimizationMOI, Ipopt, AmplNLWriter, Ipopt_jll
3+
using ModelingToolkit: get_metadata
34

45
@variables x y
56
@parameters a b
@@ -174,3 +175,13 @@ end
174175
sol = solve(prob, Ipopt.Optimizer())
175176
@test sol.minimum < 1.0
176177
end
178+
179+
@variables x
180+
o1 = (x - 1)^2
181+
c1 = [
182+
x ~ 1,
183+
]
184+
testdict = Dict(["test" => 1])
185+
sys1 = OptimizationSystem(o1, [x], [], name = :sys1, constraints = c1,
186+
metadata = testdict)
187+
@test get_metadata(sys1) == testdict

0 commit comments

Comments
 (0)