Skip to content

Commit f12f472

Browse files
authored
Merge pull request #973 from SciML/ys/connect
Add connectors
2 parents c9cf93d + e753fa4 commit f12f472

File tree

13 files changed

+200
-69
lines changed

13 files changed

+200
-69
lines changed

examples/electrical_components.jl

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,23 @@ using ModelingToolkit, OrdinaryDiffEq
33

44
# Basic electric components
55
@parameters t
6-
function Pin(;name)
6+
@connector function Pin(;name)
77
@variables v(t) i(t)
88
ODESystem(Equation[], t, [v, i], [], name=name, defaults=[v=>1.0, i=>1.0])
99
end
1010

11+
function ModelingToolkit.connect(::Type{Pin}, ps...)
12+
eqs = [
13+
0 ~ sum(p->p.i, ps) # KCL
14+
]
15+
# KVL
16+
for i in 1:length(ps)-1
17+
push!(eqs, ps[i].v ~ ps[i+1].v)
18+
end
19+
20+
return eqs
21+
end
22+
1123
function Ground(;name)
1224
@named g = Pin()
1325
eqs = [g.v ~ 0]
@@ -70,15 +82,3 @@ function Inductor(; name, L = 1.0)
7082
]
7183
ODESystem(eqs, t, [v, i], [L], systems=[p, n], defaults=Dict(L => val), name=name)
7284
end
73-
74-
function connect_pins(ps...)
75-
eqs = [
76-
0 ~ sum(p->p.i, ps) # KCL
77-
]
78-
# KVL
79-
for i in 1:length(ps)-1
80-
push!(eqs, ps[i].v ~ ps[i+1].v)
81-
end
82-
83-
return eqs
84-
end

examples/rc_model.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@ V = 1.0
99
@named ground = Ground()
1010

1111
rc_eqs = [
12-
connect_pins(source.p, resistor.p)
13-
connect_pins(resistor.n, capacitor.p)
14-
connect_pins(capacitor.n, source.n, ground.g)
12+
connect(source.p, resistor.p)
13+
connect(resistor.n, capacitor.p)
14+
connect(capacitor.n, source.n, ground.g)
1515
]
1616

1717
@named rc_model = ODESystem(rc_eqs, t, systems=[resistor, capacitor, source, ground])

examples/serial_inductor.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@ include("electrical_components.jl")
77
@named ground = Ground()
88

99
eqs = [
10-
connect_pins(source.p, resistor.p)
11-
connect_pins(resistor.n, inductor1.p)
12-
connect_pins(inductor1.n, inductor2.p)
13-
connect_pins(source.n, inductor2.n, ground.g)
10+
connect(source.p, resistor.p)
11+
connect(resistor.n, inductor1.p)
12+
connect(inductor1.n, inductor2.p)
13+
connect(source.n, inductor2.n, ground.g)
1414
]
1515

1616
@named ll_model = ODESystem(eqs, t, systems=[source, resistor, inductor1, inductor2, ground])

src/ModelingToolkit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ export SteadyStateProblem, SteadyStateProblemExpr
149149
export JumpProblem, DiscreteProblem
150150
export NonlinearSystem, OptimizationSystem
151151
export ControlSystem
152-
export alias_elimination, flatten
152+
export alias_elimination, flatten, connect, @connector
153153
export ode_order_lowering, liouville_transform
154154
export runge_kutta_discretize
155155
export PDESystem

src/systems/abstractsystem.jl

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@ for prop in [
157157
:domain
158158
:depvars
159159
:indvars
160+
:connection_type
160161
]
161162
fname1 = Symbol(:get_, prop)
162163
fname2 = Symbol(:has_, prop)
@@ -579,3 +580,64 @@ function check_eqs_u0(eqs, dvs, u0)
579580
end
580581
return nothing
581582
end
583+
584+
function with_connection_type(expr)
585+
@assert expr isa Expr && (expr.head == :function || (expr.head == :(=) &&
586+
expr.args[1] isa Expr &&
587+
expr.args[1].head == :call))
588+
589+
sig = expr.args[1]
590+
body = expr.args[2]
591+
592+
fname = sig.args[1]
593+
args = sig.args[2:end]
594+
595+
quote
596+
struct $fname
597+
$(gensym()) -> 1 # this removes the default constructor
598+
end
599+
function $fname($(args...))
600+
function f()
601+
$body
602+
end
603+
res = f()
604+
$isdefined(res, :connection_type) ? $Setfield.@set!(res.connection_type = $fname) : res
605+
end
606+
end
607+
end
608+
609+
macro connector(expr)
610+
esc(with_connection_type(expr))
611+
end
612+
613+
promote_connect_rule(::Type{T}, ::Type{S}) where {T, S} = Union{}
614+
promote_connect_rule(::Type{T}, ::Type{T}) where {T} = T
615+
promote_connect_type(t1::Type, t2::Type, ts::Type...) = promote_connect_rule(promote_connect_rule(t1, t2), ts...)
616+
@inline function promote_connect_type(::Type{T}, ::Type{S}) where {T,S}
617+
promote_connect_result(
618+
T,
619+
S,
620+
promote_connect_rule(T,S),
621+
promote_connect_rule(S,T)
622+
)
623+
end
624+
625+
promote_connect_result(::Type, ::Type, ::Type{T}, ::Type{Union{}}) where {T} = T
626+
promote_connect_result(::Type, ::Type, ::Type{Union{}}, ::Type{S}) where {S} = S
627+
promote_connect_result(::Type, ::Type, ::Type{T}, ::Type{T}) where {T} = T
628+
function promote_connect_result(::Type{T}, ::Type{S}, ::Type{P1}, ::Type{P2}) where {T,S,P1,P2}
629+
throw(ArgumentError("connection promotion for $T and $S resulted in $P1 and $P2. " *
630+
"Define promotion only in one direction."))
631+
end
632+
633+
throw_connector_promotion(T, S) = throw(ArgumentError("Don't know how to connect systems of type $S and $T"))
634+
promote_connect_result(::Type{T},::Type{S},::Type{Union{}},::Type{Union{}}) where {T,S} = throw_connector_promotion(T,S)
635+
636+
promote_connect_type(::Type{T}, ::Type{T}) where {T} = T
637+
function promote_connect_type(T, S)
638+
error("Don't know how to connect systems of type $S and $T")
639+
end
640+
641+
function connect(syss...)
642+
connect(promote_connect_type(map(get_connection_type, syss)...), syss...)
643+
end

src/systems/diffeqs/odesystem.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,10 @@ struct ODESystem <: AbstractODESystem
6969
structure: structural information of the system
7070
"""
7171
structure::Any
72+
"""
73+
type: type of the system
74+
"""
75+
connection_type::Any
7276
end
7377

7478
function ODESystem(
@@ -79,6 +83,7 @@ function ODESystem(
7983
default_u0=Dict(),
8084
default_p=Dict(),
8185
defaults=_merge(Dict(default_u0), Dict(default_p)),
86+
connection_type=nothing,
8287
)
8388
iv′ = value(iv)
8489
dvs′ = value.(dvs)
@@ -98,7 +103,7 @@ function ODESystem(
98103
if length(unique(sysnames)) != length(sysnames)
99104
throw(ArgumentError("System names must be unique."))
100105
end
101-
ODESystem(deqs, iv′, dvs′, ps′, observed, tgrad, jac, Wfact, Wfact_t, name, systems, defaults, nothing)
106+
ODESystem(deqs, iv′, dvs′, ps′, observed, tgrad, jac, Wfact, Wfact_t, name, systems, defaults, nothing, connection_type)
102107
end
103108

104109
iv_from_nested_derivative(x::Term) = operation(x) isa Differential ? iv_from_nested_derivative(arguments(x)[1]) : arguments(x)[1]

src/systems/diffeqs/sdesystem.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,10 @@ struct SDESystem <: AbstractODESystem
7171
parameters are not supplied in `ODEProblem`.
7272
"""
7373
defaults::Dict
74+
"""
75+
type: type of the system
76+
"""
77+
connection_type::Any
7478
end
7579

7680
function SDESystem(deqs::AbstractVector{<:Equation}, neqs, iv, dvs, ps;
@@ -79,7 +83,9 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs, iv, dvs, ps;
7983
default_u0=Dict(),
8084
default_p=Dict(),
8185
defaults=_merge(Dict(default_u0), Dict(default_p)),
82-
name = gensym(:SDESystem))
86+
name = gensym(:SDESystem),
87+
connection_type=nothing,
88+
)
8389
iv′ = value(iv)
8490
dvs′ = value.(dvs)
8591
ps′ = value.(ps)
@@ -94,7 +100,7 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs, iv, dvs, ps;
94100
jac = RefValue{Any}(Matrix{Num}(undef, 0, 0))
95101
Wfact = RefValue(Matrix{Num}(undef, 0, 0))
96102
Wfact_t = RefValue(Matrix{Num}(undef, 0, 0))
97-
SDESystem(deqs, neqs, iv′, dvs′, ps′, observed, tgrad, jac, Wfact, Wfact_t, name, systems, defaults)
103+
SDESystem(deqs, neqs, iv′, dvs′, ps′, observed, tgrad, jac, Wfact, Wfact_t, name, systems, defaults, connection_type)
98104
end
99105

100106
function generate_diffusion_function(sys::SDESystem, dvs = states(sys), ps = parameters(sys); kwargs...)

src/systems/jumps/jumpsystem.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,10 @@ struct JumpSystem{U <: ArrayPartition} <: AbstractSystem
4747
parameters are not supplied in `ODEProblem`.
4848
"""
4949
defaults::Dict
50+
"""
51+
type: type of the system
52+
"""
53+
connection_type::Any
5054
end
5155

5256
function JumpSystem(eqs, iv, states, ps;
@@ -55,7 +59,9 @@ function JumpSystem(eqs, iv, states, ps;
5559
default_u0=Dict(),
5660
default_p=Dict(),
5761
defaults=_merge(Dict(default_u0), Dict(default_p)),
58-
name = gensym(:JumpSystem))
62+
name = gensym(:JumpSystem),
63+
connection_type=nothing,
64+
)
5965

6066
ap = ArrayPartition(MassActionJump[], ConstantRateJump[], VariableRateJump[])
6167
for eq in eqs
@@ -75,7 +81,7 @@ function JumpSystem(eqs, iv, states, ps;
7581
defaults = todict(defaults)
7682
defaults = Dict(value(k) => value(v) for (k, v) in pairs(defaults))
7783

78-
JumpSystem{typeof(ap)}(ap, value(iv), value.(states), value.(ps), observed, name, systems, defaults)
84+
JumpSystem{typeof(ap)}(ap, value(iv), value.(states), value.(ps), observed, name, systems, defaults, connection_type)
7985
end
8086

8187
function generate_rate_function(js, rate)

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,21 +43,27 @@ struct NonlinearSystem <: AbstractSystem
4343
structure: structural information of the system
4444
"""
4545
structure::Any
46+
"""
47+
type: type of the system
48+
"""
49+
connection_type::Any
4650
end
4751

4852
function NonlinearSystem(eqs, states, ps;
49-
observed = [],
50-
name = gensym(:NonlinearSystem),
53+
observed=[],
54+
name=gensym(:NonlinearSystem),
5155
default_u0=Dict(),
5256
default_p=Dict(),
5357
defaults=_merge(Dict(default_u0), Dict(default_p)),
54-
systems = NonlinearSystem[])
58+
systems=NonlinearSystem[],
59+
connection_type=nothing,
60+
)
5561
if !(isempty(default_u0) && isempty(default_p))
5662
Base.depwarn("`default_u0` and `default_p` are deprecated. Use `defaults` instead.", :NonlinearSystem, force=true)
5763
end
5864
defaults = todict(defaults)
5965
defaults = Dict(value(k) => value(v) for (k, v) in pairs(defaults))
60-
NonlinearSystem(eqs, value.(states), value.(ps), observed, name, systems, defaults, nothing)
66+
NonlinearSystem(eqs, value.(states), value.(ps), observed, name, systems, defaults, nothing, connection_type)
6167
end
6268

6369
function calculate_jacobian(sys::NonlinearSystem;sparse=false,simplify=false)

src/systems/pde/pdesystem.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,16 @@ struct PDESystem <: ModelingToolkit.AbstractSystem
5252
parameters are not supplied in `ODEProblem`.
5353
"""
5454
defaults::Dict
55-
@add_kwonly function PDESystem(eqs, bcs, domain, indvars, depvars, ps = SciMLBase.NullParameters(), defaults = Dict())
56-
new(eqs, bcs, domain, indvars, depvars, ps, defaults)
55+
"""
56+
type: type of the system
57+
"""
58+
connection_type::Any
59+
@add_kwonly function PDESystem(eqs, bcs, domain, indvars, depvars,
60+
ps=SciMLBase.NullParameters();
61+
defaults=Dict(),
62+
connection_type=nothing,
63+
)
64+
new(eqs, bcs, domain, indvars, depvars, ps, defaults, connection_type)
5765
end
5866
end
5967

0 commit comments

Comments
 (0)