Skip to content

Commit 96030e6

Browse files
authored
Merge pull request #2174 from SciML/myb/parse
Simple macro-based frontend
2 parents 884c239 + 897c9fa commit 96030e6

File tree

5 files changed

+320
-7
lines changed

5 files changed

+320
-7
lines changed

Project.toml

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
2828
Latexify = "23fbe1c1-3f47-55db-b15f-69d7ec21a316"
2929
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
3030
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
31+
MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078"
3132
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
3233
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
3334
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
@@ -46,6 +47,12 @@ Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
4647
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
4748
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
4849

50+
[weakdeps]
51+
DeepDiffs = "ab62b9b5-e342-54a8-a765-a90f495de1a6"
52+
53+
[extensions]
54+
MTKDeepDiffsExt = "DeepDiffs"
55+
4956
[compat]
5057
AbstractTrees = "0.3, 0.4"
5158
ArrayInterface = "6, 7"
@@ -67,6 +74,7 @@ JuliaFormatter = "1"
6774
JumpProcesses = "9.1"
6875
LabelledArrays = "1.3"
6976
Latexify = "0.11, 0.12, 0.13, 0.14, 0.15, 0.16"
77+
MLStyle = "0.4.17"
7078
MacroTools = "0.5"
7179
NaNMath = "0.3, 1"
7280
RecursiveArrayTools = "2.3"
@@ -84,9 +92,6 @@ UnPack = "0.1, 1.0"
8492
Unitful = "1.1"
8593
julia = "1.6"
8694

87-
[extensions]
88-
MTKDeepDiffsExt = "DeepDiffs"
89-
9095
[extras]
9196
AmplNLWriter = "7c4d4715-977e-5154-bfe0-e096adeac482"
9297
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
@@ -113,6 +118,3 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
113118

114119
[targets]
115120
test = ["AmplNLWriter", "BenchmarkTools", "ControlSystemsMTK", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials"]
116-
117-
[weakdeps]
118-
DeepDiffs = "ab62b9b5-e342-54a8-a765-a90f495de1a6"

src/ModelingToolkit.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ include("utils.jl")
118118
include("domains.jl")
119119

120120
include("systems/abstractsystem.jl")
121+
include("systems/model_parsing.jl")
121122
include("systems/connectors.jl")
122123
include("systems/callbacks.jl")
123124

@@ -181,7 +182,7 @@ export JumpProblem, DiscreteProblem
181182
export NonlinearSystem, OptimizationSystem, ConstraintsSystem
182183
export alias_elimination, flatten
183184
export connect, @connector, Connection, Flow, Stream, instream
184-
export @component
185+
export @component, @model
185186
export isinput, isoutput, getbounds, hasbounds, isdisturbance, istunable, getdist, hasdist,
186187
tunable_parameters, isirreducible, getdescription, hasdescription, isbinaryvar,
187188
isintegervar

src/systems/model_parsing.jl

Lines changed: 235 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,235 @@
1+
macro connector(name::Symbol, body)
2+
esc(connector_macro(__module__, name, body))
3+
end
4+
5+
struct Model{F, S}
6+
f::F
7+
structure::S
8+
end
9+
(m::Model)(args...; kw...) = m.f(args...; kw...)
10+
11+
using MLStyle
12+
function connector_macro(mod, name, body)
13+
if !Meta.isexpr(body, :block)
14+
err = """
15+
connector body must be a block! It should be in the form of
16+
```
17+
@connector Pin begin
18+
v(t) = 1
19+
(i(t) = 1), [connect = Flow]
20+
end
21+
```
22+
"""
23+
error(err)
24+
end
25+
vs = Num[]
26+
dict = Dict{Symbol, Any}()
27+
for arg in body.args
28+
arg isa LineNumberNode && continue
29+
push!(vs, Num(parse_variable_def!(dict, mod, arg, :variables)))
30+
end
31+
iv = get(dict, :independent_variable, nothing)
32+
if iv === nothing
33+
error("$name doesn't have a independent variable")
34+
end
35+
quote
36+
$name = $Model((; name) -> begin
37+
var"#___sys___" = $ODESystem($(Equation[]), $iv, $vs, $([]);
38+
name)
39+
$Setfield.@set!(var"#___sys___".connector_type=$connector_type(var"#___sys___"))
40+
end, $dict)
41+
end
42+
end
43+
44+
function parse_variable_def!(dict, mod, arg, varclass)
45+
MLStyle.@match arg begin
46+
::Symbol => generate_var!(dict, arg, varclass)
47+
Expr(:call, a, b) => generate_var!(dict, a, b, varclass)
48+
Expr(:(=), a, b) => begin
49+
var = parse_variable_def!(dict, mod, a, varclass)
50+
def = parse_default(mod, b)
51+
dict[varclass][getname(var)][:default] = def
52+
setdefault(var, def)
53+
end
54+
Expr(:tuple, a, b) => begin
55+
var = parse_variable_def!(dict, mod, a, varclass)
56+
meta = parse_metadata(mod, b)
57+
if (ct = get(meta, VariableConnectType, nothing)) !== nothing
58+
dict[varclass][getname(var)][:connection_type] = nameof(ct)
59+
end
60+
set_var_metadata(var, meta)
61+
end
62+
_ => error("$arg cannot be parsed")
63+
end
64+
end
65+
66+
function generate_var(a, varclass)
67+
var = Symbolics.variable(a)
68+
if varclass == :parameters
69+
var = toparam(var)
70+
end
71+
var
72+
end
73+
function generate_var!(dict, a, varclass)
74+
var = generate_var(a, varclass)
75+
vd = get!(dict, varclass) do
76+
Dict{Symbol, Dict{Symbol, Any}}()
77+
end
78+
vd[a] = Dict{Symbol, Any}()
79+
var
80+
end
81+
function generate_var!(dict, a, b, varclass)
82+
iv = generate_var(b, :variables)
83+
prev_iv = get!(dict, :independent_variable) do
84+
iv
85+
end
86+
@assert isequal(iv, prev_iv)
87+
vd = get!(dict, varclass) do
88+
Dict{Symbol, Dict{Symbol, Any}}()
89+
end
90+
vd[a] = Dict{Symbol, Any}()
91+
var = Symbolics.variable(a, T = SymbolicUtils.FnType{Tuple{Real}, Real})(iv)
92+
if varclass == :parameters
93+
var = toparam(var)
94+
end
95+
var
96+
end
97+
function parse_default(mod, a)
98+
a = Base.remove_linenums!(deepcopy(a))
99+
MLStyle.@match a begin
100+
Expr(:block, a) => get_var(mod, a)
101+
::Symbol => get_var(mod, a)
102+
::Number => a
103+
_ => error("Cannot parse default $a")
104+
end
105+
end
106+
function parse_metadata(mod, a)
107+
MLStyle.@match a begin
108+
Expr(:vect, eles...) => Dict(parse_metadata(mod, e) for e in eles)
109+
Expr(:(=), a, b) => Symbolics.option_to_metadata_type(Val(a)) => get_var(mod, b)
110+
_ => error("Cannot parse metadata $a")
111+
end
112+
end
113+
function set_var_metadata(a, ms)
114+
for (m, v) in ms
115+
a = setmetadata(a, m, v)
116+
end
117+
a
118+
end
119+
function get_var(mod::Module, b)
120+
b isa Symbol ? getproperty(mod, b) : b
121+
end
122+
123+
macro model(name::Symbol, expr)
124+
esc(model_macro(__module__, name, expr))
125+
end
126+
function model_macro(mod, name, expr)
127+
exprs = Expr(:block)
128+
dict = Dict{Symbol, Any}()
129+
comps = Symbol[]
130+
ext = Ref{Any}(nothing)
131+
vs = Symbol[]
132+
ps = Symbol[]
133+
eqs = Expr[]
134+
for arg in expr.args
135+
arg isa LineNumberNode && continue
136+
arg.head == :macrocall || error("$arg is not valid syntax. Expected a macro call.")
137+
parse_model!(exprs.args, comps, ext, eqs, vs, ps, dict, mod, arg)
138+
end
139+
iv = get(dict, :independent_variable, nothing)
140+
if iv === nothing
141+
iv = dict[:independent_variable] = variable(:t)
142+
end
143+
sys = :($ODESystem($Equation[$(eqs...)], $iv, [$(vs...)], [$(ps...)];
144+
systems = [$(comps...)], name))
145+
if ext[] === nothing
146+
push!(exprs.args, sys)
147+
else
148+
push!(exprs.args, :($extend($sys, $(ext[]))))
149+
end
150+
:($name = $Model((; name) -> $exprs, $dict))
151+
end
152+
function parse_model!(exprs, comps, ext, eqs, vs, ps, dict, mod, arg)
153+
mname = arg.args[1]
154+
body = arg.args[end]
155+
if mname == Symbol("@components")
156+
parse_components!(exprs, comps, dict, body)
157+
elseif mname == Symbol("@extend")
158+
parse_extend!(exprs, ext, dict, body)
159+
elseif mname == Symbol("@variables")
160+
parse_variables!(exprs, vs, dict, mod, body, :variables)
161+
elseif mname == Symbol("@parameters")
162+
parse_variables!(exprs, ps, dict, mod, body, :parameters)
163+
elseif mname == Symbol("@equations")
164+
parse_equations!(exprs, eqs, dict, body)
165+
else
166+
error("$mname is not handled.")
167+
end
168+
end
169+
function parse_components!(exprs, cs, dict, body)
170+
expr = Expr(:block)
171+
push!(exprs, expr)
172+
comps = Vector{String}[]
173+
for arg in body.args
174+
arg isa LineNumberNode && continue
175+
MLStyle.@match arg begin
176+
Expr(:(=), a, b) => begin
177+
push!(cs, a)
178+
push!(comps, [String(a), String(b.args[1])])
179+
arg = deepcopy(arg)
180+
b = deepcopy(arg.args[2])
181+
push!(b.args, Expr(:kw, :name, Meta.quot(a)))
182+
arg.args[2] = b
183+
push!(expr.args, arg)
184+
end
185+
_ => error("`@components` only takes assignment expressions. Got $arg")
186+
end
187+
end
188+
dict[:components] = comps
189+
end
190+
function parse_extend!(exprs, ext, dict, body)
191+
expr = Expr(:block)
192+
push!(exprs, expr)
193+
body = deepcopy(body)
194+
MLStyle.@match body begin
195+
Expr(:(=), a, b) => begin
196+
vars = nothing
197+
if Meta.isexpr(b, :(=))
198+
vars = a
199+
if !Meta.isexpr(vars, :tuple)
200+
error("`@extend` destructuring only takes an tuple as LHS. Got $body")
201+
end
202+
a, b = b.args
203+
vars, a, b
204+
end
205+
ext[] = a
206+
push!(b.args, Expr(:kw, :name, Meta.quot(a)))
207+
dict[:extend] = [Symbol.(vars.args), a, b.args[1]]
208+
push!(expr.args, :($a = $b))
209+
if vars !== nothing
210+
push!(expr.args, :(@unpack $vars = $a))
211+
end
212+
end
213+
_ => error("`@extend` only takes an assignment expression. Got $body")
214+
end
215+
end
216+
function parse_variables!(exprs, vs, dict, mod, body, varclass)
217+
expr = Expr(:block)
218+
push!(exprs, expr)
219+
for arg in body.args
220+
arg isa LineNumberNode && continue
221+
vv = parse_variable_def!(dict, mod, arg, varclass)
222+
v = Num(vv)
223+
name = getname(v)
224+
push!(vs, name)
225+
push!(expr.args, :($name = $v))
226+
end
227+
end
228+
function parse_equations!(exprs, eqs, dict, body)
229+
for arg in body.args
230+
arg isa LineNumberNode && continue
231+
push!(eqs, arg)
232+
end
233+
# TODO: does this work with TOML?
234+
dict[:equations] = readable_code.(eqs)
235+
end

test/model_parsing.jl

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
using ModelingToolkit, Test
2+
3+
@connector RealInput begin u(t), [input = true] end
4+
@connector RealOutput begin u(t), [output = true] end
5+
@model Constant begin
6+
@components begin output = RealOutput() end
7+
@parameters begin k, [description = "Constant output value of block"] end
8+
@equations begin output.u ~ k end
9+
end
10+
11+
@variables t
12+
D = Differential(t)
13+
14+
@connector Pin begin
15+
v(t) = 0 # Potential at the pin [V]
16+
i(t), [connect = Flow] # Current flowing into the pin [A]
17+
end
18+
19+
@model OnePort begin
20+
@components begin
21+
p = Pin()
22+
n = Pin()
23+
end
24+
@variables begin
25+
v(t)
26+
i(t)
27+
end
28+
@equations begin
29+
v ~ p.v - n.v
30+
0 ~ p.i + n.i
31+
i ~ p.i
32+
end
33+
end
34+
35+
@model Ground begin
36+
@components begin g = Pin() end
37+
@equations begin g.v ~ 0 end
38+
end
39+
40+
@model Resistor begin
41+
@extend v, i = oneport = OnePort()
42+
@parameters begin R = 1 end
43+
@equations begin v ~ i * R end
44+
end
45+
46+
@model Capacitor begin
47+
@extend v, i = oneport = OnePort()
48+
@parameters begin C = 1 end
49+
@equations begin D(v) ~ i / C end
50+
end
51+
52+
@model Voltage begin
53+
@extend v, i = oneport = OnePort()
54+
@components begin V = RealInput() end
55+
@equations begin v ~ V.u end
56+
end
57+
58+
@model RC begin
59+
@components begin
60+
resistor = Resistor()
61+
capacitor = Capacitor()
62+
source = Voltage()
63+
constant = Constant()
64+
ground = Ground()
65+
end
66+
@equations begin
67+
connect(constant.output, source.V)
68+
connect(source.p, resistor.p)
69+
connect(resistor.n, capacitor.p)
70+
connect(capacitor.n, source.n, ground.g)
71+
end
72+
end
73+
@named rc = RC()
74+
@test length(equations(structural_simplify(rc))) == 1

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ using SafeTestsets, Test
2525
@safetestset "Reduction Test" begin include("reduction.jl") end
2626
@safetestset "ODAEProblem Test" begin include("odaeproblem.jl") end
2727
@safetestset "Components Test" begin include("components.jl") end
28+
@safetestset "Model Parsing Test" begin include("model_parsing.jl") end
2829
@safetestset "print_tree" begin include("print_tree.jl") end
2930
@safetestset "Error Handling" begin include("error_handling.jl") end
3031
@safetestset "StructuralTransformations" begin include("structural_transformation/runtests.jl") end

0 commit comments

Comments
 (0)