Skip to content

Commit d74ab35

Browse files
authored
Merge pull request #1114 from SciML/myb/imp
Some usability improvements
2 parents d755ac3 + a16e873 commit d74ab35

File tree

5 files changed

+99
-31
lines changed

5 files changed

+99
-31
lines changed

docs/src/tutorials/acausal_components.md

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -95,12 +95,11 @@ rc_eqs = [
9595
connect(capacitor.n, source.n, ground.g)
9696
]
9797

98-
@named rc_model = ODESystem(rc_eqs, t,
99-
systems=[resistor, capacitor, source, ground])
98+
@named rc_model = compose(ODESystem(rc_eqs, t),
99+
[resistor, capacitor, source, ground])
100100
sys = structural_simplify(rc_model)
101101
u0 = [
102102
capacitor.v => 0.0
103-
capacitor.p.i => 0.0
104103
]
105104
prob = ODAEProblem(sys, u0, (0, 10.0))
106105
sol = solve(prob, Tsit5())
@@ -289,15 +288,15 @@ rc_eqs = [
289288
Finally we build our four component model with these connection rules:
290289
291290
```julia
292-
@named rc_model = ODESystem(rc_eqs, t,
293-
systems=[resistor, capacitor, source, ground])
291+
@named rc_model = compose(ODESystem(rc_eqs, t)
292+
[resistor, capacitor, source, ground])
294293
```
295294
296-
Notice that this model is acasual because we have not specified anything about
297-
the causality of the model. We have simply specified what is true about each
298-
of the variables. This forms a system of differential-algebraic equations
299-
(DAEs) which define the evolution of each state of the system. The
300-
equations are:
295+
Note that we can also specify the subsystems in a vector. This model is acasual
296+
because we have not specified anything about the causality of the model. We have
297+
simply specified what is true about each of the variables. This forms a system
298+
of differential-algebraic equations (DAEs) which define the evolution of each
299+
state of the system. The equations are:
301300
302301
```julia
303302
equations(rc_model)
@@ -369,9 +368,10 @@ parameters(rc_model)
369368
This system could be solved directly as a DAE using [one of the DAE solvers
370369
from DifferentialEquations.jl](https://diffeq.sciml.ai/stable/solvers/dae_solve/).
371370
However, let's take a second to symbolically simplify the system before doing the
372-
solve. The function `structural_simplify` looks for all of the equalities and
373-
eliminates unnecessary variables to build the leanest numerical representation
374-
of the system. Let's see what it does here:
371+
solve. Although we can use ODE solvers that handles mass matrices to solve the
372+
above system directly, we want to run the `structural_simplify` function first,
373+
as it eliminates many unnecessary variables to build the leanest numerical
374+
representation of the system. Let's see what it does here:
375375
376376
```julia
377377
sys = structural_simplify(rc_model)
@@ -410,16 +410,13 @@ plot(sol)
410410
411411
![](https://user-images.githubusercontent.com/1814174/109416295-55184100-798b-11eb-96d1-5bb7e40135ba.png)
412412
413-
However, we can also choose to use the "torn nonlinear system" to remove all
414-
of the algebraic variables from the solution of the system. Note that this
415-
requires having done `structural_simplify`. MTK can numerically solve all
416-
the unreduced algebraic equations numerically. This is done by using
417-
`ODAEProblem` like:
413+
Since we have run `structural_simplify`, MTK can numerically solve all the
414+
unreduced algebraic equations numerically using the `ODAEProblem` (note the
415+
letter `A`):
418416
419417
```julia
420418
u0 = [
421419
capacitor.v => 0.0
422-
capacitor.p.i => 0.0
423420
]
424421
prob = ODAEProblem(sys, u0, (0, 10.0))
425422
sol = solve(prob, Rodas4())

examples/rc_model.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,4 @@ rc_eqs = [
1414
connect(capacitor.n, source.n, ground.g)
1515
]
1616

17-
@named rc_model = compose(ODESystem(rc_eqs, t), resistor, capacitor, source, ground)
17+
@named rc_model = compose(ODESystem(rc_eqs, t), [resistor, capacitor, source, ground])

src/systems/abstractsystem.jl

Lines changed: 73 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -597,13 +597,16 @@ function Base.show(io::IO, ::MIME"text/plain", sys::AbstractSystem)
597597
return nothing
598598
end
599599

600-
function _named(expr)
600+
function split_assign(expr)
601601
if !(expr isa Expr && expr.head === :(=) && expr.args[2].head === :call)
602602
throw(ArgumentError("expression should be of the form `sys = foo(a, b)`"))
603603
end
604604
name, call = expr.args
605+
end
605606

607+
function _named(name, call, runtime=false)
606608
has_kw = false
609+
call isa Expr || throw(Meta.ParseError("The rhs must be an Expr. Got $call."))
607610
if length(call.args) >= 2 && call.args[2] isa Expr
608611
# canonicalize to use `:parameters`
609612
if call.args[2].head === :kw
@@ -626,18 +629,78 @@ function _named(expr)
626629
kws = call.args[2].args
627630

628631
if !any(kw->(kw isa Symbol ? kw : kw.args[1]) == :name, kws) # don't overwrite `name` kwarg
629-
pushfirst!(kws, Expr(:kw, :name, Meta.quot(name)))
632+
pushfirst!(kws, Expr(:kw, :name, runtime ? name : Meta.quot(name)))
633+
end
634+
call
635+
end
636+
637+
function _named_idxs(name::Symbol, idxs, call)
638+
if call.head !== :->
639+
throw(ArgumentError("Not an anonymous function"))
640+
end
641+
if !isa(call.args[1], Symbol)
642+
throw(ArgumentError("not a single-argument anonymous function"))
630643
end
631-
:($name = $call)
644+
sym, ex = call.args
645+
ex = Base.Cartesian.poplinenum(ex)
646+
ex = _named(:(Symbol($(Meta.quot(name)), :_, $sym)), ex, true)
647+
ex = Base.Cartesian.poplinenum(ex)
648+
:($name = $map($sym->$ex, $idxs))
632649
end
633650

651+
check_name(name) = name isa Symbol || throw(Meta.ParseError("The lhs must be a symbol (a) or a ref (a[1:10]). Got $name."))
652+
634653
"""
635-
$(SIGNATURES)
654+
@named y = foo(x)
655+
@named y[1:10] = foo(x)
656+
@named y 1:10 i -> foo(x*i)
636657
637658
Rewrite `@named y = foo(x)` to `y = foo(x; name=:y)`.
659+
660+
Rewrite `@named y[1:10] = foo(x)` to `y = map(i′->foo(x; name=Symbol(:y_, i′)), 1:10)`.
661+
662+
Rewrite `@named y 1:10 i -> foo(x*i)` to `y = map(i->foo(x*i; name=Symbol(:y_, i)), 1:10)`.
663+
664+
Examples:
665+
```julia
666+
julia> using ModelingToolkit
667+
668+
julia> foo(i; name) = i, name
669+
foo (generic function with 1 method)
670+
671+
julia> x = 41
672+
41
673+
674+
julia> @named y = foo(x)
675+
(41, :y)
676+
677+
julia> @named y[1:3] = foo(x)
678+
3-element Vector{Tuple{Int64, Symbol}}:
679+
(41, :y_1)
680+
(41, :y_2)
681+
(41, :y_3)
682+
683+
julia> @named y 1:3 i -> foo(x*i)
684+
3-element Vector{Tuple{Int64, Symbol}}:
685+
(41, :y_1)
686+
(82, :y_2)
687+
(123, :y_3)
688+
```
638689
"""
639690
macro named(expr)
640-
esc(_named(expr))
691+
name, call = split_assign(expr)
692+
if Meta.isexpr(name, :ref)
693+
name, idxs = name.args
694+
check_name(name)
695+
esc(_named_idxs(name, idxs, :($(gensym()) -> $call)))
696+
else
697+
check_name(name)
698+
esc(:($name = $(_named(name, call))))
699+
end
700+
end
701+
702+
macro named(name::Symbol, idxs, call)
703+
esc(_named_idxs(name, idxs, call))
641704
end
642705

643706
function _config(expr, namespace)
@@ -854,15 +917,14 @@ Base.:(&)(sys::AbstractSystem, basesys::AbstractSystem; name::Symbol=nameof(sys)
854917
compose multiple systems together. The resulting system would inherit the first
855918
system's name.
856919
"""
857-
compose(syss::AbstractSystem...; name=nameof(first(syss))) = compose(collect(syss); name=name)
858-
function compose(syss::AbstractArray{<:AbstractSystem}; name=nameof(first(syss)))
859-
nsys = length(syss)
860-
nsys >= 2 || throw(ArgumentError("There must be at least 2 systems. Got $nsys systems."))
861-
sys = first(syss)
920+
function compose(sys::AbstractSystem, systems::AbstractArray{<:AbstractSystem}; name=nameof(first(syss)))
921+
nsys = length(systems)
922+
nsys >= 1 || throw(ArgumentError("There must be at least 1 subsystem. Got $nsys subsystems."))
862923
@set! sys.name = name
863-
@set! sys.systems = syss[2:end]
924+
@set! sys.systems = systems
864925
return sys
865926
end
927+
compose(syss::AbstractSystem...; name=nameof(first(syss))) = compose(first(syss), collect(syss[2:end]); name=name)
866928
Base.:()(sys1::AbstractSystem, sys2::AbstractSystem) = compose(sys1, sys2)
867929

868930
UnPack.unpack(sys::ModelingToolkit.AbstractSystem, ::Val{p}) where p = getproperty(sys, p; namespace=false)

test/components.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@ sol = solve(prob, Rodas4())
2020
@test iszero(sol[ground.g.v])
2121
@test sol[resistor.v] == sol[source.p.v] - sol[capacitor.p.v]
2222

23+
u0 = [
24+
capacitor.v => 0.0
25+
]
2326
prob = ODAEProblem(sys, u0, (0, 10.0))
2427
sol = solve(prob, Tsit5())
2528

test/direct.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,3 +251,9 @@ if VERSION >= v"1.5"
251251
@named cool_name = foo(;ff)
252252
@test collect(cool_name) == [pp; :ff => ff]
253253
end
254+
255+
foo(i; name) = i, name
256+
@named goo[1:3] = foo(10)
257+
@test isequal(goo, [(10, Symbol(:goo_, i)) for i in 1:3])
258+
@named koo 1:3 i -> foo(10i)
259+
@test isequal(koo, [(10i, Symbol(:koo_, i)) for i in 1:3])

0 commit comments

Comments
 (0)