Skip to content

Commit 3926542

Browse files
Fix generate_function
1 parent d40dbe0 commit 3926542

File tree

3 files changed

+31
-18
lines changed

3 files changed

+31
-18
lines changed

src/equations.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ Base.:~(lhs::Expression, rhs::Expression) = Equation(lhs, rhs)
1111
Base.:~(lhs::Expression, rhs::Number ) = Equation(lhs, rhs)
1212
Base.:~(lhs::Number , rhs::Expression) = Equation(lhs, rhs)
1313

14-
_is_parameter(iv) = (O::Operation) -> O.op.known && !isequal(O, iv)
1514
_is_known(O::Operation) = O.op.known
1615
_is_unknown(O::Operation) = !O.op.known
1716

@@ -28,8 +27,6 @@ function extract_elements(eqs, predicates)
2827
return result
2928
end
3029

31-
get_args(O::Operation) = O.args
32-
get_args(eq::Equation) = Expression[eq.lhs, eq.rhs]
3330
vars(exprs) = foldl(vars!, exprs; init = Set{Variable}())
3431
function vars!(vars, O)
3532
isa(O, Operation) || return vars

src/systems/diffeqs/diffeqsystem.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,9 +86,11 @@ function (f::DiffEqToExpr)(O::Operation)
8686
end
8787
(f::DiffEqToExpr)(x) = convert(Expr, x)
8888

89-
function generate_function(sys::DiffEqSystem; version::FunctionVersion = ArrayFunction)
89+
function generate_function(sys::DiffEqSystem, vs, ps; version::FunctionVersion = ArrayFunction)
9090
rhss = [deq.rhs for deq sys.eqs]
91-
return build_function(rhss, sys.dvs, sys.ps, (sys.iv.name,), DiffEqToExpr(sys); version = version)
91+
vs′ = [clean(v) for v vs]
92+
ps′ = [clean(p) for p ps]
93+
return build_function(rhss, vs′, ps′, (sys.iv.name,), DiffEqToExpr(sys); version = version)
9294
end
9395

9496

test/system_construction.jl

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,13 @@ eqs = [D(x) ~ σ*(y-x),
2020
D(z) ~ x*y - β*z]
2121
de = DiffEqSystem(eqs)
2222
test_diffeq_inference("standard", de, :t, (:x, :y, :z), (, , ))
23-
generate_function(de)
24-
generate_function(de; version=ModelingToolkit.SArrayFunction)
23+
generate_function(de, [x,y,z], [σ,ρ,β])
24+
generate_function(de, [x,y,z], [σ,ρ,β]; version=ModelingToolkit.SArrayFunction)
2525
jac_expr = generate_jacobian(de)
2626
jac = calculate_jacobian(de)
27-
f = ODEFunction(de)
27+
@test_broken begin
28+
f = ODEFunction(de)
29+
end
2830
ModelingToolkit.generate_ode_iW(de)
2931

3032
@testset "time-varying parameters" begin
@@ -35,24 +37,34 @@ ModelingToolkit.generate_ode_iW(de)
3537
de = DiffEqSystem(eqs)
3638
test_diffeq_inference("global iv-varying", de, :t, (:x, :y, :z), (:σ′, , ))
3739
@test begin
38-
f = eval(generate_function(de))
40+
f = eval(generate_function(de, [x,y,z], [σ′,ρ,β]))
3941
du = [0.0,0.0,0.0]
4042
f(du, [1.0,2.0,3.0], [x->x+7,2,3], 5.0)
41-
du [12, -3, -7]
43+
du [11, -3, -7]
4244
end
4345

4446
@parameters σ
4547
eqs = [D(x) ~ σ(t-1)*(y-x),
4648
D(y) ~ x*-z)-y,
4749
D(z) ~ x*y - β*z]
4850
de = DiffEqSystem(eqs)
49-
test_diffeq_inference("internal iv-varying", de, :t, (:x, :y, :z), (, , ))
51+
test_diffeq_inference("single internal iv-varying", de, :t, (:x, :y, :z), (, , ))
5052
@test begin
51-
f = eval(generate_function(de))
53+
f = eval(generate_function(de, [x,y,z], [σ,ρ,β]))
5254
du = [0.0,0.0,0.0]
5355
f(du, [1.0,2.0,3.0], [x->x+7,2,3], 5.0)
5456
du [11, -3, -7]
5557
end
58+
59+
eqs = [D(x) ~ x + 10σ(t-1) + 100σ(t-2) + 1000σ(t^2)]
60+
de = DiffEqSystem(eqs)
61+
test_diffeq_inference("many internal iv-varying", de, :t, (:x,), (,))
62+
@test begin
63+
f = eval(generate_function(de, [x], [σ]))
64+
du = [0.0]
65+
f(du, [1.0], [t -> t + 2], 5.0)
66+
du [27561]
67+
end
5668
end
5769

5870
@test_broken begin
@@ -78,9 +90,11 @@ eqs = [D(x) ~ σ*a,
7890
D(y) ~ x*-z)-y,
7991
D(z) ~ x*y - β*z]
8092
de = DiffEqSystem(eqs)
81-
generate_function(de)
93+
generate_function(de, [x,y,z], [σ,ρ,β])
8294
jac = calculate_jacobian(de)
83-
f = ODEFunction(de)
95+
@test_broken begin
96+
f = ODEFunction(de)
97+
end
8498

8599
@test_broken begin
86100
# Define a nonlinear system
@@ -95,7 +109,7 @@ for el in (:vs, :ps)
95109
@test names2 == names
96110
end
97111

98-
generate_function(ns)
112+
generate_function(ns, [x,y,z], [σ,ρ,β])
99113
end
100114

101115
@derivatives D'~t
@@ -105,7 +119,7 @@ eqs = [D(x) ~ -A*x,
105119
D(y) ~ A*x - B*_x]
106120
de = DiffEqSystem(eqs)
107121
@test begin
108-
f = eval(generate_function(de))
122+
f = eval(generate_function(de, [x,y], [A,B,C]))
109123
du = [0.0,0.0]
110124
f(du, [1.0,2.0], [1,2,3], 0.0)
111125
du [-1, -1/3]
@@ -133,7 +147,7 @@ jac = calculate_jacobian(ns)
133147
@test isequal(jac[3,2], x)
134148
@test isequal(jac[3,3], -1 * β)
135149
end
136-
nlsys_func = generate_function(ns)
150+
nlsys_func = generate_function(ns, [x,y,z], [σ,ρ,β])
137151
jac_func = generate_jacobian(ns)
138152
f = @eval eval(nlsys_func)
139153

@@ -143,7 +157,7 @@ eqs = [0 ~ σ*a,
143157
0 ~ x*-z)-y,
144158
0 ~ x*y - β*z]
145159
ns = NonlinearSystem(eqs,[x,y,z],[σ,ρ,β])
146-
nlsys_func = generate_function(ns)
160+
nlsys_func = generate_function(ns, [x,y,z], [σ,ρ,β])
147161
jac = calculate_jacobian(ns)
148162
jac = generate_jacobian(ns)
149163
end

0 commit comments

Comments
 (0)