Skip to content

Commit d9fb298

Browse files
Merge branch 'master' into hg/feature/varfuns
1 parent af476bd commit d9fb298

File tree

7 files changed

+40
-22
lines changed

7 files changed

+40
-22
lines changed

REQUIRE

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
julia 1.0
1+
julia 1.1
22
MacroTools
33
DiffEqBase
44
DiffRules

src/differentials.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,10 @@ _repeat_apply(f, n) = n == 1 ? f : f ∘ _repeat_apply(f, n-1)
6161
function _differential_macro(x)
6262
ex = Expr(:block)
6363
lhss = Symbol[]
64+
x = x isa Tuple && first(x).head == :tuple ? first(x).args : x # tuple handling
6465
x = flatten_expr!(x)
6566
for di in x
66-
@assert di isa Expr && di.args[1] == :~ "@derivatives expects a form that looks like `@derivatives D''~t E'~t`"
67+
@assert di isa Expr && di.args[1] == :~ "@derivatives expects a form that looks like `@derivatives D''~t E'~t` or `@derivatives (D''~t), (E'~t)`"
6768
lhs = di.args[2]
6869
rhs = di.args[3]
6970
order, lhs = count_order(lhs)

src/systems/diffeqs/diffeqsystem.jl

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -99,10 +99,10 @@ function generate_jacobian(sys::ODESystem; version::FunctionVersion = ArrayFunct
9999
return build_function(jac, sys.dvs, sys.ps, (sys.iv.name,); version = version)
100100
end
101101

102-
struct DiffEqToExpr
102+
struct ODEToExpr
103103
sys::ODESystem
104104
end
105-
function (f::DiffEqToExpr)(O::Operation)
105+
function (f::ODEToExpr)(O::Operation)
106106
if isa(O.op, Variable)
107107
isequal(O.op, f.sys.iv) && return O.op.name # independent variable
108108
O.op f.sys.dvs && return O.op.name # dependent variables
@@ -111,41 +111,39 @@ function (f::DiffEqToExpr)(O::Operation)
111111
end
112112
return build_expr(:call, Any[O.op; f.(O.args)])
113113
end
114-
(f::DiffEqToExpr)(x) = convert(Expr, x)
114+
(f::ODEToExpr)(x) = convert(Expr, x)
115115

116116
function generate_function(sys::ODESystem, dvs, ps; version::FunctionVersion = ArrayFunction)
117117
rhss = [deq.rhs for deq sys.eqs]
118118
dvs′ = [clean(dv) for dv dvs]
119119
ps′ = [clean(p) for p ps]
120-
return build_function(rhss, dvs′, ps′, (sys.iv.name,), DiffEqToExpr(sys); version = version)
120+
return build_function(rhss, dvs′, ps′, (sys.iv.name,), ODEToExpr(sys); version = version)
121121
end
122122

123123

124-
function generate_ode_iW(sys::ODESystem, simplify=true; version::FunctionVersion = ArrayFunction)
124+
function generate_factorized_W(sys::ODESystem, simplify=true; version::FunctionVersion = ArrayFunction)
125125
jac = calculate_jacobian(sys)
126126

127127
gam = Variable(:gam; known = true)()
128128

129129
W = LinearAlgebra.I - gam*jac
130-
W = SMatrix{size(W,1),size(W,2)}(W)
131-
iW = inv(W)
130+
Wfact = lu(W, Val(false), check=false).factors
132131

133132
if simplify
134-
iW = simplify_constants.(iW)
133+
Wfact = simplify_constants.(Wfact)
135134
end
136135

137-
W = inv(LinearAlgebra.I/gam - jac)
138-
W = SMatrix{size(W,1),size(W,2)}(W)
139-
iW_t = inv(W)
136+
W_t = LinearAlgebra.I/gam - jac
137+
Wfact_t = lu(W_t, Val(false), check=false).factors
140138
if simplify
141-
iW_t = simplify_constants.(iW_t)
139+
Wfact_t = simplify_constants.(Wfact_t)
142140
end
143141

144142
vs, ps = sys.dvs, sys.ps
145-
iW_func = build_function(iW , vs, ps, (:gam,:t); version = version)
146-
iW_t_func = build_function(iW_t, vs, ps, (:gam,:t); version = version)
143+
Wfact_func = build_function(Wfact , vs, ps, (:gam,:t), ODEToExpr(sys); version = version)
144+
Wfact_t_func = build_function(Wfact_t, vs, ps, (:gam,:t), ODEToExpr(sys); version = version)
147145

148-
return (iW_func, iW_t_func)
146+
return (Wfact_func, Wfact_t_func)
149147
end
150148

151149
function DiffEqBase.ODEFunction(sys::ODESystem, dvs, ps; version::FunctionVersion = ArrayFunction)

src/variables.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,12 @@ function _parse_vars(macroname, known, x)
3838
# y
3939
# z
4040
# end
41+
x = x isa Tuple && first(x) isa Expr && first(x).head == :tuple ? first(x).args : x # tuple handling
4142
x = flatten_expr!(x)
4243
for _var in x
4344
iscall = isa(_var, Expr) && _var.head == :call
4445
issym = _var isa Symbol
45-
@assert iscall || issym "@$macroname expects a tuple of expressions (`@$macroname x y z(t)`)"
46+
@assert iscall || issym "@$macroname expects a tuple of expressions or an expression of a tuple (`@$macroname x y z(t)` or `@$macroname x, y, z(t)`)"
4647

4748
if iscall
4849
var_name = _var.args[1]
@@ -59,8 +60,8 @@ function _parse_vars(macroname, known, x)
5960
return ex
6061
end
6162
macro variables(xs...)
62-
esc(_parse_vars(:Variable, false, xs))
63+
esc(_parse_vars(:variables, false, xs))
6364
end
6465
macro parameters(xs...)
65-
esc(_parse_vars(:Param, true, xs))
66+
esc(_parse_vars(:parameters, true, xs))
6667
end

test/derivatives.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ using Test
66
@variables x(t) y(t) z(t)
77
@derivatives D'~t D2''~t Dx'~x
88

9+
@test @macroexpand(@derivatives D'~t D2''~t) == @macroexpand(@derivatives (D'~t), (D2''~t))
10+
911
@test isequal(expand_derivatives(D(t)), 1)
1012
@test isequal(expand_derivatives(D(D(t))), 0)
1113

test/system_construction.jl

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,15 @@ generate_function(de, [x,y,z], [σ,ρ,β]; version=ModelingToolkit.SArrayFunctio
3636
jac_expr = generate_jacobian(de)
3737
jac = calculate_jacobian(de)
3838
f = ODEFunction(de, [x,y,z], [σ,ρ,β])
39-
ModelingToolkit.generate_ode_iW(de)
39+
fw, fwt = map(eval, ModelingToolkit.generate_factorized_W(de))
40+
du = zeros(3)
41+
u = collect(1:3)
42+
p = collect(4:6)
43+
f(du, u, p, 0.1)
44+
@test du == [4, 0, -16]
45+
FW = zeros(3, 3)
46+
fw(FW, u, p, 0.2, 0.1)
47+
fwt(FW, u, p, 0.2, 0.1)
4048

4149
@testset "time-varying parameters" begin
4250
@parameters σ′(t-1)
@@ -90,6 +98,10 @@ lowered_eqs = [D(u_tt) ~ 2u_tt + u_t + x_t + 1
9098
D(u) ~ u_t
9199
D(x) ~ x_t]
92100
@test de1 == ODESystem(lowered_eqs)
101+
test_diffeq_inference("first-order transform", de1, t, [u_tt, x_t, u_t, u, x], [])
102+
du = zeros(5)
103+
ODEFunction(de1, [u_tt, x_t, u_t, u, x], [])(du, ones(5), nothing, 0.1)
104+
@test du == [5.0, 3.0, 1.0, 1.0, 1.0]
93105

94106
# Internal calculations
95107
a = y - x

test/variable_parsing.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@ using ModelingToolkit
22
using Test
33

44
@parameters t()
5-
@variables x(t) y(t) z(t)
5+
@variables x(t) y(t) # test multi-arg
6+
@variables z(t) # test single-arg
67
x1 = Variable(:x)(t)
78
y1 = Variable(:y)(t)
89
z1 = Variable(:z)(t)
@@ -27,3 +28,6 @@ D1 = Differential(t)
2728
@test D1 == D
2829

2930
@test isequal(x y + 1, (x < y + 1) | (x == y + 1))
31+
32+
@test @macroexpand(@parameters x, y, z(t)) == @macroexpand(@parameters x y z(t))
33+
@test @macroexpand(@variables x, y, z(t)) == @macroexpand(@variables x y z(t))

0 commit comments

Comments
 (0)