Skip to content

Commit 4e355c2

Browse files
committed
Merge remote-tracking branch 'origin/master' into s/unflatten
2 parents add4021 + 929f6bc commit 4e355c2

File tree

5 files changed

+96
-10
lines changed

5 files changed

+96
-10
lines changed

src/systems/diffeqs/modelingtoolkitize.jl

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,16 @@ function modelingtoolkitize(prob::DiffEqBase.ODEProblem)
77
prob.f isa DiffEqBase.AbstractParameterizedFunction &&
88
return (prob.f.sys, prob.f.sys.states, prob.f.sys.ps)
99
@parameters t
10+
11+
if prob.p isa Tuple || prob.p isa NamedTuple
12+
p = [x for x in prob.p]
13+
else
14+
p = prob.p
15+
end
16+
1017
vars = reshape([Variable(:x, i)(t) for i in eachindex(prob.u0)],size(prob.u0))
11-
params = prob.p isa DiffEqBase.NullParameters ? [] :
12-
reshape([Variable(,i)() for i in eachindex(prob.p)],size(prob.p))
18+
params = p isa DiffEqBase.NullParameters ? [] :
19+
reshape([Variable(,i)() for i in eachindex(p)],size(Array(p)))
1320
@derivatives D'~t
1421

1522
rhs = [D(var) for var in vars]
@@ -38,9 +45,16 @@ function modelingtoolkitize(prob::DiffEqBase.SDEProblem)
3845
prob.f isa DiffEqBase.AbstractParameterizedFunction &&
3946
return (prob.f.sys, prob.f.sys.states, prob.f.sys.ps)
4047
@parameters t
48+
49+
if prob.p isa Tuple || prob.p isa NamedTuple
50+
p = [x for x in prob.p]
51+
else
52+
p = prob.p
53+
end
54+
4155
vars = reshape([Variable(:x, i)(t) for i in eachindex(prob.u0)],size(prob.u0))
42-
params = prob.p isa DiffEqBase.NullParameters ? [] :
43-
reshape([Variable(,i)() for i in eachindex(prob.p)],size(prob.p))
56+
params = p isa DiffEqBase.NullParameters ? [] :
57+
reshape([Variable(,i)() for i in eachindex(p)],size(p))
4458
@derivatives D'~t
4559

4660
rhs = [D(var) for var in vars]
@@ -65,7 +79,7 @@ function modelingtoolkitize(prob::DiffEqBase.SDEProblem)
6579
end
6680
end
6781
deqs = vcat([rhs[i] ~ lhs[i] for i in eachindex(prob.u0)]...)
68-
82+
6983
de = SDESystem(deqs,neqs,t,vec(vars),vec(params))
7084

7185
de

src/systems/reduction.jl

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
export alias_elimination
22

33
function flatten(sys::ODESystem)
4-
ODESystem(equations(sys),
5-
independent_variable(sys),
6-
observed=observed(sys))
4+
if isempty(sys.systems)
5+
return sys
6+
else
7+
return ODESystem(equations(sys),
8+
independent_variable(sys),
9+
observed=observed(sys))
10+
end
711
end
812

913

@@ -32,8 +36,16 @@ function make_lhs_0(eq)
3236
end
3337

3438
function alias_elimination(sys::ODESystem)
35-
eqs = vcat(equations(sys),
36-
make_lhs_0.(observed(sys)))
39+
eqs = vcat(equations(sys), observed(sys))
40+
41+
# make all algebraic equations have 0 on LHS
42+
eqs = map(eqs) do eq
43+
if eq.lhs isa Operation && eq.lhs.op isa Differential
44+
eq
45+
else
46+
make_lhs_0(eq)
47+
end
48+
end
3749

3850
new_stateops = map(eqs) do eq
3951
if eq.lhs isa Operation && eq.lhs.op isa Differential
@@ -49,11 +61,13 @@ function alias_elimination(sys::ODESystem)
4961

5062
newstates = convert.(Variable, new_stateops)
5163

64+
5265
alg_idxs = findall(x->x.lhs isa Constant && iszero(x.lhs), eqs)
5366

5467
eliminate = setdiff(convert.(Variable, all_vars), newstates)
5568

5669
vars = map(x->x(sys.iv()), eliminate)
70+
5771
outputs = solve_for(eqs[alg_idxs], vars)
5872

5973
diffeqs = eqs[setdiff(1:length(eqs), alg_idxs)]

test/modelingtoolkitize.jl

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
using OrdinaryDiffEq, ModelingToolkit
2+
const N = 32
3+
const xyd_brusselator = range(0,stop=1,length=N)
4+
brusselator_f(x, y, t) = (((x-0.3)^2 + (y-0.6)^2) <= 0.1^2) * (t >= 1.1) * 5.
5+
limit(a, N) = ModelingToolkit.ifelse(a == N+1, 1, ModelingToolkit.ifelse(a == 0, N, a))
6+
function brusselator_2d_loop(du, u, p, t)
7+
A, B, alpha, dx = p
8+
alpha = alpha/dx^2
9+
@inbounds for I in CartesianIndices((N, N))
10+
i, j = Tuple(I)
11+
x, y = xyd_brusselator[I[1]], xyd_brusselator[I[2]]
12+
ip1, im1, jp1, jm1 = limit(i+1, N), limit(i-1, N), limit(j+1, N), limit(j-1, N)
13+
du[i,j,1] = alpha*(u[im1,j,1] + u[ip1,j,1] + u[i,jp1,1] + u[i,jm1,1] - 4u[i,j,1]) +
14+
B + u[i,j,1]^2*u[i,j,2] - (A + 1)*u[i,j,1] + brusselator_f(x, y, t)
15+
du[i,j,2] = alpha*(u[im1,j,2] + u[ip1,j,2] + u[i,jp1,2] + u[i,jm1,2] - 4u[i,j,2]) +
16+
A*u[i,j,1] - u[i,j,1]^2*u[i,j,2]
17+
end
18+
end
19+
20+
# Test with tuple parameters
21+
p = (3.4, 1., 10., step(xyd_brusselator))
22+
23+
function init_brusselator_2d(xyd)
24+
N = length(xyd)
25+
u = zeros(N, N, 2)
26+
for I in CartesianIndices((N, N))
27+
x = xyd[I[1]]
28+
y = xyd[I[2]]
29+
u[I,1] = 22*(y*(1-y))^(3/2)
30+
u[I,2] = 27*(x*(1-x))^(3/2)
31+
end
32+
u
33+
end
34+
u0 = init_brusselator_2d(xyd_brusselator)
35+
36+
# Test with 3-tensor inputs
37+
prob_ode_brusselator_2d = ODEProblem(brusselator_2d_loop,
38+
u0,(0.,11.5),p)
39+
40+
modelingtoolkitize(prob_ode_brusselator_2d)

test/reduction.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,3 +95,20 @@ test_equal.(observed(aliased_flattened_system), [
9595
a ~ lorenz2.y + -1 * lorenz1.x,
9696
lorenz2.u ~ lorenz2.x + -1 * (lorenz2.y + lorenz2.z),
9797
])
98+
99+
100+
# issue #578
101+
102+
let
103+
@variables t x(t) y(t) z(t);
104+
@derivatives D'~t;
105+
eqs = [
106+
D(x) ~ x + y
107+
x ~ y
108+
];
109+
sys = ODESystem(eqs, t, [x], []);
110+
asys = alias_elimination(ModelingToolkit.flatten(sys))
111+
112+
test_equal.(asys.eqs, [D(x) ~ 2x])
113+
test_equal.(asys.observed, [y ~ x])
114+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ using SafeTestsets, Test
1919
@safetestset "ControlSystem Test" begin include("controlsystem.jl") end
2020
@safetestset "Build Targets Test" begin include("build_targets.jl") end
2121
@safetestset "Domain Test" begin include("domains.jl") end
22+
@safetestset "Modelingtoolkitize Test" begin include("modelingtoolkitize.jl") end
2223
@safetestset "Constraints Test" begin include("constraints.jl") end
2324
@safetestset "Reduction Test" begin include("reduction.jl") end
2425
@safetestset "PDE Construction Test" begin include("pde.jl") end

0 commit comments

Comments
 (0)