Skip to content

Commit ee83816

Browse files
Remove ability to store multiple independent variables
Not currently supported.
1 parent d5ff6cf commit ee83816

File tree

5 files changed

+23
-21
lines changed

5 files changed

+23
-21
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ Each operation builds an `Operation` type, and thus `eqs` is an array of
4646
analyzed by other programs. We can turn this into a `DiffEqSystem` via:
4747

4848
```julia
49-
de = DiffEqSystem(eqs,[t],[x,y,z],[σ,ρ,β])
49+
de = DiffEqSystem(eqs,t,[x,y,z],[σ,ρ,β])
5050
de = DiffEqSystem(eqs)
5151
```
5252

src/equations.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ Base.:~(lhs::Number , rhs::Expression) = Equation(lhs, rhs)
1313

1414

1515
_is_dependent(x::Variable) = x.subtype === :Unknown && !isempty(x.dependents)
16-
_is_parameter(ivs) = x -> x.subtype === :Parameter && x ivs
16+
_is_parameter(iv) = x -> x.subtype === :Parameter && x iv
1717
_subtype(subtype::Symbol) = x -> x.subtype === subtype
1818

1919
function extract_elements(eqs, predicates)

src/systems/diffeqs/diffeqsystem.jl

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,26 +16,28 @@ get_args(eq::DiffEq) = Expression[eq.var, eq.rhs]
1616

1717
struct DiffEqSystem <: AbstractSystem
1818
eqs::Vector{DiffEq}
19-
ivs::Vector{Variable}
19+
iv::Variable
2020
dvs::Vector{Variable}
2121
ps::Vector{Variable}
2222
jac::RefValue{Matrix{Expression}}
23-
function DiffEqSystem(eqs, ivs, dvs, ps)
23+
function DiffEqSystem(eqs, iv, dvs, ps)
2424
jac = RefValue(Matrix{Expression}(undef, 0, 0))
25-
new(eqs, ivs, dvs, ps, jac)
25+
new(eqs, iv, dvs, ps, jac)
2626
end
2727
end
2828

2929
function DiffEqSystem(eqs)
3030
dvs, = extract_elements(eqs, [_is_dependent])
3131
ivs = unique(vcat((dv.dependents for dv dvs)...))
32-
ps, = extract_elements(eqs, [_is_parameter(ivs)])
33-
DiffEqSystem(eqs, ivs, dvs, ps)
32+
length(ivs) == 1 || throw(ArgumentError("one independent variable currently supported"))
33+
iv = first(ivs)
34+
ps, = extract_elements(eqs, [_is_parameter(iv)])
35+
DiffEqSystem(eqs, iv, dvs, ps)
3436
end
3537

36-
function DiffEqSystem(eqs, ivs)
37-
dvs, ps = extract_elements(eqs, [_is_dependent, _is_parameter(ivs)])
38-
DiffEqSystem(eqs, ivs, dvs, ps)
38+
function DiffEqSystem(eqs, iv)
39+
dvs, ps = extract_elements(eqs, [_is_dependent, _is_parameter(iv)])
40+
DiffEqSystem(eqs, iv, dvs, ps)
3941
end
4042

4143

@@ -44,12 +46,12 @@ function generate_ode_function(sys::DiffEqSystem; version::FunctionVersion = Arr
4446
param_exprs = [:($(sys.ps[i].name) = p[$i]) for i in eachindex(sys.ps)]
4547
sys_exprs = build_equals_expr.(sys.eqs)
4648
if version === ArrayFunction
47-
dvar_exprs = [:(du[$i] = $(Symbol("$(sys.dvs[i].name)_$(sys.ivs[1].name)"))) for i in eachindex(sys.dvs)]
49+
dvar_exprs = [:(du[$i] = $(Symbol("$(sys.dvs[i].name)_$(sys.iv.name)"))) for i in eachindex(sys.dvs)]
4850
exprs = vcat(var_exprs,param_exprs,sys_exprs,dvar_exprs)
4951
block = expr_arr_to_block(exprs)
5052
:((du,u,p,t)->$(toexpr(block)))
5153
elseif version === SArrayFunction
52-
dvar_exprs = [:($(Symbol("$(sys.dvs[i].name)_$(sys.ivs[1].name)"))) for i in eachindex(sys.dvs)]
54+
dvar_exprs = [:($(Symbol("$(sys.dvs[i].name)_$(sys.iv.name)"))) for i in eachindex(sys.dvs)]
5355
svector_expr = quote
5456
E = eltype(tuple($(dvar_exprs...)))
5557
T = StaticArrays.similar_type(typeof(u), E)

src/systems/diffeqs/first_order_transform.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,8 @@ end
1212

1313
function ode_order_lowering(sys::DiffEqSystem; kwargs...)
1414
eqs = sys.eqs
15-
ivs = sys.ivs
1615
eqs_lowered = ode_order_lowering(eqs; kwargs...)
17-
DiffEqSystem(eqs_lowered, ivs)
16+
DiffEqSystem(eqs_lowered, sys.iv)
1817
end
1918
ode_order_lowering(eqs; naming_scheme = "_") = ode_order_lowering!(deepcopy(eqs), naming_scheme)
2019
function ode_order_lowering!(eqs, naming_scheme)

test/system_construction.jl

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ using Test
1111
eqs = [D(x) ~ σ*(y-x),
1212
D(y) ~ x*-z)-y,
1313
D(z) ~ x*y - β*z]
14-
de = DiffEqSystem(eqs,[t],[x,y,z],[σ,ρ,β])
14+
de = DiffEqSystem(eqs,t,[x,y,z],[σ,ρ,β])
1515
ModelingToolkit.generate_ode_function(de)
1616
ModelingToolkit.generate_ode_function(de;version=ModelingToolkit.SArrayFunction)
1717
jac_expr = ModelingToolkit.generate_ode_jacobian(de)
@@ -20,10 +20,11 @@ f = ODEFunction(de)
2020
ModelingToolkit.generate_ode_iW(de)
2121

2222
# Differential equation with automatic extraction of variables
23-
de2 = DiffEqSystem(eqs, [t])
23+
de2 = DiffEqSystem(eqs, t)
2424

2525
function test_vars_extraction(de, de2)
26-
for el in (:ivs, :dvs, :ps)
26+
@test de.iv == de2.iv
27+
for el in (:dvs, :ps)
2728
names2 = sort(collect(var.name for var in getfield(de2,el)))
2829
names = sort(collect(var.name for var in getfield(de,el)))
2930
@test names2 == names
@@ -37,7 +38,7 @@ test_vars_extraction(de, de2)
3738
@Unknown u(t) u_tt(t) u_t(t) x_t(t)
3839
eqs = [D3(u) ~ 2(D2(u)) + D(u) + D(x) + 1
3940
D2(x) ~ D(x) + 2]
40-
de = DiffEqSystem(eqs, [t])
41+
de = DiffEqSystem(eqs, t)
4142
de1 = ode_order_lowering(de)
4243
lowered_eqs = [D(u_tt) ~ 2u_tt + u_t + x_t + 1
4344
D(x_t) ~ x_t + 2
@@ -51,7 +52,7 @@ a = y - x
5152
eqs = [D(x) ~ σ*a,
5253
D(y) ~ x*-z)-y,
5354
D(z) ~ x*y - β*z]
54-
de = DiffEqSystem(eqs,[t],[x,y,z],[σ,ρ,β])
55+
de = DiffEqSystem(eqs,t,[x,y,z],[σ,ρ,β])
5556
ModelingToolkit.generate_ode_function(de)
5657
jac = ModelingToolkit.calculate_jacobian(de)
5758
f = ODEFunction(de)
@@ -75,8 +76,8 @@ ModelingToolkit.generate_nlsys_function(ns)
7576
_x = y / C
7677
eqs = [D(x) ~ -A*x,
7778
D(y) ~ A*x - B*_x]
78-
de = DiffEqSystem(eqs,[t],[x,y],[A,B,C])
79-
test_vars_extraction(de, DiffEqSystem(eqs,[t]))
79+
de = DiffEqSystem(eqs,t,[x,y],[A,B,C])
80+
test_vars_extraction(de, DiffEqSystem(eqs,t))
8081
test_vars_extraction(de, DiffEqSystem(eqs))
8182
@test eval(ModelingToolkit.generate_ode_function(de))([0.0,0.0],[1.0,2.0],[1,2,3],0.0) -1/3
8283

0 commit comments

Comments
 (0)