Skip to content

Commit f81a613

Browse files
authored
Merge pull request #1681 from SciML/s/calling-convention-update
Use new dependent array variable convention
2 parents bb0fa0a + ed4c859 commit f81a613

File tree

13 files changed

+63
-41
lines changed

13 files changed

+63
-41
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ Setfield = "0.7, 0.8, 1"
7373
SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2"
7474
StaticArrays = "0.10, 0.11, 0.12, 1.0"
7575
SymbolicUtils = "0.19"
76-
Symbolics = "4.5 - 4.8"
76+
Symbolics = "4.9"
7777
UnPack = "0.1, 1.0"
7878
Unitful = "1.1"
7979
julia = "1.6"

docs/src/tutorials/spring_mass.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,14 @@ D = Differential(t)
1313
1414
function Mass(; name, m = 1.0, xy = [0., 0.], u = [0., 0.])
1515
ps = @parameters m=m
16-
sts = @variables pos[1:2](t)=xy v[1:2](t)=u
16+
sts = @variables pos(t)[1:2]=xy v(t)[1:2]=u
1717
eqs = scalarize(D.(pos) .~ v)
1818
ODESystem(eqs, t, [pos..., v...], ps; name)
1919
end
2020
2121
function Spring(; name, k = 1e4, l = 1.)
2222
ps = @parameters k=k l=l
23-
@variables x(t), dir[1:2](t)
23+
@variables x(t), dir(t)[1:2]
2424
ODESystem(Equation[], t, [x, dir...], ps; name)
2525
end
2626
@@ -63,7 +63,7 @@ For each component we use a Julia function that returns an `ODESystem`. At the t
6363
```@example component
6464
function Mass(; name, m = 1.0, xy = [0., 0.], u = [0., 0.])
6565
ps = @parameters m=m
66-
sts = @variables pos[1:2](t)=xy v[1:2](t)=u
66+
sts = @variables pos(t)[1:2]=xy v(t)[1:2]=u
6767
eqs = scalarize(D.(pos) .~ v)
6868
ODESystem(eqs, t, [pos..., v...], ps; name)
6969
end
@@ -86,7 +86,7 @@ Next we build the spring component. It is characterised by the spring constant `
8686
```@example component
8787
function Spring(; name, k = 1e4, l = 1.)
8888
ps = @parameters k=k l=l
89-
@variables x(t), dir[1:2](t)
89+
@variables x(t), dir(t)[1:2]
9090
ODESystem(Equation[], t, [x, dir...], ps; name)
9191
end
9292
```

src/parameters.jl

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,19 @@
1-
import SymbolicUtils: symtype, term, hasmetadata
1+
import SymbolicUtils: symtype, term, hasmetadata, issym
22
struct MTKParameterCtx end
33

4-
isparameter(x::Num) = isparameter(value(x))
5-
isparameter(x::Symbolic) = getmetadata(x, MTKParameterCtx, false)
6-
isparameter(x) = false
4+
function isparameter(x)
5+
x = unwrap(x)
6+
if istree(x) && operation(x) isa Symbolic
7+
getmetadata(x, MTKParameterCtx, false) ||
8+
isparameter(operation(x))
9+
elseif istree(x) && operation(x) == (getindex)
10+
isparameter(arguments(x)[1])
11+
elseif x isa Symbolic
12+
getmetadata(x, MTKParameterCtx, false)
13+
else
14+
false
15+
end
16+
end
717

818
"""
919
toparam(s::Sym)
@@ -15,13 +25,11 @@ function toparam(s)
1525
Symbolics.wrap(toparam(Symbolics.unwrap(s)))
1626
elseif s isa AbstractArray
1727
map(toparam, s)
18-
elseif symtype(s) <: AbstractArray
19-
Symbolics.recurse_and_apply(toparam, s)
2028
else
2129
setmetadata(s, MTKParameterCtx, true)
2230
end
2331
end
24-
toparam(s::Num) = Num(toparam(value(s)))
32+
toparam(s::Num) = wrap(toparam(value(s)))
2533

2634
"""
2735
tovar(s::Sym)

src/systems/connectors.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -282,9 +282,7 @@ function generate_connection_equations_and_stream_connections(csets::AbstractVec
282282

283283
for cset in csets
284284
v = cset.set[1].v
285-
if hasmetadata(v, Symbolics.GetindexParent)
286-
v = getparent(v)
287-
end
285+
v = getparent(v, v)
288286
vtype = get_connection_type(v)
289287
if vtype === Stream
290288
push!(stream_connections, cset)

src/utils.jl

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ function check_variables(dvs, iv)
148148
for dv in dvs
149149
isequal(iv, dv) &&
150150
throw(ArgumentError("Independent variable $iv not allowed in dependent variables."))
151-
(is_delay_var(iv, dv) || occursin(iv, iv_from_nested_derivative(dv))) ||
151+
(is_delay_var(iv, dv) || occursin(iv, dv)) ||
152152
throw(ArgumentError("Variable $dv is not a function of independent variable $iv."))
153153
end
154154
end
@@ -201,19 +201,35 @@ function check_equations(eqs, iv)
201201
end
202202
end
203203
"Get all the independent variables with respect to which differentials/differences are taken."
204-
function collect_ivs_from_nested_operator!(ivs, x::Term, target_op)
205-
op = operation(x)
204+
function collect_ivs_from_nested_operator!(ivs, x, target_op)
205+
if !istree(x)
206+
return
207+
end
208+
op = operation(unwrap(x))
206209
if op isa target_op
207210
push!(ivs, get_iv(op))
208-
collect_ivs_from_nested_operator!(ivs, arguments(x)[1], target_op)
211+
x = if target_op <: Differential
212+
op.x
213+
elseif target_op <: Difference
214+
op.t
215+
else
216+
error("Unknown target op type in collect_ivs $target_op. Pass Difference or Differential")
217+
end
218+
collect_ivs_from_nested_operator!(ivs, x, target_op)
209219
end
210220
end
211221

212-
function iv_from_nested_derivative(x::Term, op = Differential)
213-
operation(x) isa op ? iv_from_nested_derivative(arguments(x)[1], op) : arguments(x)[1]
222+
function iv_from_nested_derivative(x, op = Differential)
223+
if istree(x) && operation(x) == getindex
224+
iv_from_nested_derivative(arguments(x)[1], op)
225+
elseif istree(x)
226+
operation(x) isa op ? iv_from_nested_derivative(arguments(x)[1], op) : arguments(x)[1]
227+
elseif issym(x)
228+
x
229+
else
230+
nothing
231+
end
214232
end
215-
iv_from_nested_derivative(x::Sym, op = Differential) = x
216-
iv_from_nested_derivative(x, op = Differential) = nothing
217233

218234
hasdefault(v) = hasmetadata(v, Symbolics.VariableDefaultValue)
219235
getdefault(v) = value(getmetadata(v, Symbolics.VariableDefaultValue))

test/latexify.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ eqs = [D(x) ~ σ * (y - x) * D(x - y) / D(z),
3030
# Latexify.@generate_test latexify(eqs)
3131
@test_reference "latexify/10.tex" latexify(eqs)
3232

33-
@variables u[1:3](t)
33+
@variables u(t)[1:3]
3434
@parameters p[1:3]
3535
eqs = [D(u[1]) ~ p[3] * (u[2] - u[1]),
3636
0 ~ p[2] * p[3] * u[1] * (p[1] - u[1]) / 10 - u[2],

test/latexify/20.tex

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
\begin{align}
2-
\frac{du_1(t)}{dt} =& \left( - \mathrm{u_1}\left( t \right) + \mathrm{u_2}\left( t \right) \right) p_3 \\
3-
0 =& - \mathrm{u_2}\left( t \right) + \frac{1}{10} \left( - \mathrm{u_1}\left( t \right) + p_1 \right) \mathrm{u_1}\left( t \right) p_2 p_3 \\
4-
\frac{du_3(t)}{dt} =& \left( \mathrm{u_2}\left( t \right) \right)^{\frac{2}{3}} \mathrm{u_1}\left( t \right) - \mathrm{u_3}\left( t \right) p_3
2+
\mathrm{\frac{d}{d t}}\left( u(t)_1 \right) =& \left( - u(t)_1 + u(t)_2 \right) p_3 \\
3+
0 =& - u(t)_2 + \frac{1}{10} \left( - u(t)_1 + p_1 \right) p_2 p_3 u(t)_1 \\
4+
\mathrm{\frac{d}{d t}}\left( u(t)_3 \right) =& u(t)_2^{\frac{2}{3}} u(t)_1 - p_3 u(t)_3
55
\end{align}

test/latexify/30.tex

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
\begin{align}
2-
\frac{du_1(t)}{dt} =& \left( - \mathrm{u_1}\left( t \right) + \mathrm{u_2}\left( t \right) \right) p_3 \\
3-
\frac{du_2(t)}{dt} =& - \mathrm{u_2}\left( t \right) + \frac{1}{10} \left( - \mathrm{u_1}\left( t \right) + p_1 \right) \mathrm{u_1}\left( t \right) p_2 p_3 \\
4-
\frac{du_3(t)}{dt} =& \left( \mathrm{u_2}\left( t \right) \right)^{\frac{2}{3}} \mathrm{u_1}\left( t \right) - \mathrm{u_3}\left( t \right) p_3
2+
\mathrm{\frac{d}{d t}}\left( u(t)_1 \right) =& \left( - u(t)_1 + u(t)_2 \right) p_3 \\
3+
\mathrm{\frac{d}{d t}}\left( u(t)_2 \right) =& - u(t)_2 + \frac{1}{10} \left( - u(t)_1 + p_1 \right) p_2 p_3 u(t)_1 \\
4+
\mathrm{\frac{d}{d t}}\left( u(t)_3 \right) =& u(t)_2^{\frac{2}{3}} u(t)_1 - p_3 u(t)_3
55
\end{align}

test/mass_matrix.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using OrdinaryDiffEq, ModelingToolkit, Test, LinearAlgebra
22
@parameters t
3-
@variables y[1:3](t)
3+
@variables y(t)[1:3]
44
@parameters k[1:3]
55
D = Differential(t)
66

test/odesystem.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -376,7 +376,7 @@ end
376376

377377
# issue 1109
378378
let
379-
@variables t x[1:3, 1:3](t)
379+
@variables t x(t)[1:3, 1:3]
380380
D = Differential(t)
381381
@named sys = ODESystem(D.(x) .~ x)
382382
@test_nowarn structural_simplify(sys)
@@ -386,7 +386,7 @@ end
386386
using Symbolics: unwrap, wrap
387387
using LinearAlgebra
388388
@variables t
389-
sts = @variables x[1:3](t)=[1, 2, 3.0] y(t)=1.0
389+
sts = @variables x(t)[1:3]=[1, 2, 3.0] y(t)=1.0
390390
ps = @parameters p[1:3] = [1, 2, 3]
391391
D = Differential(t)
392392
eqs = [collect(D.(x) .~ x)
@@ -487,7 +487,7 @@ function foo(a::Num, ms::AbstractVector)
487487
wrap(term(foo, a, term(SVector, ms...)))
488488
end
489489
foo(a, ms::AbstractVector) = a + sum(ms)
490-
@variables t x(t) ms[1:3](t)
490+
@variables t x(t) ms(t)[1:3]
491491
D = Differential(t)
492492
ms = collect(ms)
493493
eqs = [D(x) ~ foo(x, ms); D.(ms) .~ 1]
@@ -589,7 +589,7 @@ end
589589
let
590590
@parameters t
591591
D = Differential(t)
592-
@variables x[1:2](t) = zeros(2)
592+
@variables x(t)[1:2] = zeros(2)
593593
@variables y(t) = 0
594594
@parameters k = 1
595595
eqs = [D(x[1]) ~ x[2]
@@ -680,7 +680,7 @@ let
680680
eqs_to_lhs(eqs) = eq_to_lhs.(eqs)
681681

682682
@parameters σ=10 ρ=28 β=8 / 3 sigma rho beta
683-
@variables t t2 x(t)=1 y(t)=0 z(t)=0 x2(t2)=1 y2(t2)=0 z2(t2)=0 u[1:3](t2)
683+
@variables t t2 x(t)=1 y(t)=0 z(t)=0 x2(t2)=1 y2(t2)=0 z2(t2)=0 u(t2)[1:3]
684684

685685
D = Differential(t)
686686
D2 = Differential(t2)
@@ -749,7 +749,7 @@ end
749749
let
750750
@parameters t
751751

752-
u = collect(first(@variables u[1:4](t)))
752+
u = collect(first(@variables u(t)[1:4]))
753753
Dt = Differential(t)
754754

755755
eqs = [Differential(t)(u[2]) - 1.1u[1] ~ 0

0 commit comments

Comments
 (0)