Skip to content

Commit 3685d82

Browse files
Fix DiffEqSystem
1 parent f57fe6d commit 3685d82

File tree

4 files changed

+37
-50
lines changed

4 files changed

+37
-50
lines changed

src/equations.jl

Lines changed: 22 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -12,40 +12,31 @@ Base.:~(lhs::Expression, rhs::Number ) = Equation(lhs, rhs)
1212
Base.:~(lhs::Number , rhs::Expression) = Equation(lhs, rhs)
1313

1414

15-
function extract_elements(eqs, targetmap, default = nothing)
16-
elems = Dict{Symbol, Vector{Variable}}()
17-
names = Dict{Symbol, Set{Symbol}}()
18-
if default == nothing
19-
targets = unique(collect(values(targetmap)))
20-
else
21-
targets = [unique(collect(values(targetmap))), default]
22-
end
23-
for target in targets
24-
elems[target] = Vector{Variable}()
25-
names[target] = Set{Symbol}()
26-
end
27-
for eq in eqs
28-
extract_elements!(eq, elems, names, targetmap, default)
15+
_is_derivative(x::Variable) = x.diff !== nothing
16+
_is_dependent(x::Variable) = x.subtype === :DependentVariable && !isempty(x.dependents)
17+
_subtype(subtype::Symbol) = x -> x.subtype === subtype
18+
19+
function extract_elements(eqs, predicates)
20+
result = [Variable[] for p predicates]
21+
vars = foldl(vars!, eqs; init=Set{Variable}())
22+
23+
for var vars
24+
for (i, p) enumerate(predicates)
25+
p(var) && (push!(result[i], var); break)
26+
end
2927
end
30-
Tuple(elems[target] for target in targets)
28+
29+
return result
3130
end
32-
# Walk the tree recursively and push variables into the right set
33-
function extract_elements!(op, elems, names, targetmap, default)
31+
32+
function vars!(vars, op)
3433
args = isa(op, Equation) ? Expression[op.lhs, op.rhs] : op.args
3534

36-
for arg in args
37-
if arg isa Operation
38-
extract_elements!(arg, elems, names, targetmap, default)
39-
elseif arg isa Variable
40-
if default == nothing
41-
target = haskey(targetmap, arg.subtype) ? targetmap[arg.subtype] : continue
42-
else
43-
target = haskey(targetmap, arg.subtype) ? targetmap[arg.subtype] : default
44-
end
45-
if !in(arg.name, names[target])
46-
push!(names[target], arg.name)
47-
push!(elems[target], arg)
48-
end
49-
end
35+
for arg args
36+
isa(arg, Operation) ? vars!(vars, arg) :
37+
isa(arg, Variable) ? push!(vars, arg) :
38+
nothing
5039
end
40+
41+
return vars
5142
end

src/systems/diffeqs/diffeqsystem.jl

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -19,34 +19,31 @@ end
1919

2020
function DiffEqSystem(eqs; iv_name = :IndependentVariable,
2121
dv_name = :DependentVariable,
22-
v_name = :Variable,
2322
p_name = :Parameter)
24-
targetmap = Dict(iv_name => iv_name, dv_name => dv_name, v_name => v_name,
25-
p_name => p_name)
26-
ivs, dvs, vs, ps = extract_elements(eqs, targetmap)
27-
DiffEqSystem(eqs, ivs, dvs, vs, ps, iv_name, dv_name, p_name, Matrix{Expression}(0,0))
23+
predicates = [_is_derivative, _subtype(iv_name), _is_dependent, _subtype(dv_name), _subtype(p_name)]
24+
_, ivs, dvs, vs, ps = extract_elements(eqs, predicates)
25+
DiffEqSystem(eqs, ivs, dvs, vs, ps, iv_name, dv_name, p_name, Matrix{Expression}(undef,0,0))
2826
end
2927

3028
function DiffEqSystem(eqs, ivs;
3129
dv_name = :DependentVariable,
32-
v_name = :Variable,
3330
p_name = :Parameter)
34-
targetmap = Dict(dv_name => dv_name, v_name => v_name, p_name => p_name)
35-
dvs, vs, ps = extract_elements(eqs, targetmap)
31+
predicates = [_is_derivative, _is_dependent, _subtype(dv_name), _subtype(p_name)]
32+
_, dvs, vs, ps = extract_elements(eqs, predicates)
3633
DiffEqSystem(eqs, ivs, dvs, vs, ps, ivs[1].subtype, dv_name, p_name, Matrix{Expression}(undef,0,0))
3734
end
3835

3936
function generate_ode_function(sys::DiffEqSystem;version = ArrayFunction)
40-
var_exprs = [:($(sys.dvs[i].name) = u[$i]) for i in 1:length(sys.dvs)]
41-
param_exprs = [:($(sys.ps[i].name) = p[$i]) for i in 1:length(sys.ps)]
37+
var_exprs = [:($(sys.dvs[i].name) = u[$i]) for i in eachindex(sys.dvs)]
38+
param_exprs = [:($(sys.ps[i].name) = p[$i]) for i in eachindex(sys.ps)]
4239
sys_exprs = build_equals_expr.(sys.eqs)
4340
if version == ArrayFunction
44-
dvar_exprs = [:(du[$i] = $(Symbol("$(sys.dvs[i].name)_$(sys.ivs[1].name)"))) for i in 1:length(sys.dvs)]
41+
dvar_exprs = [:(du[$i] = $(Symbol("$(sys.dvs[i].name)_$(sys.ivs[1].name)"))) for i in eachindex(sys.dvs)]
4542
exprs = vcat(var_exprs,param_exprs,sys_exprs,dvar_exprs)
4643
block = expr_arr_to_block(exprs)
4744
:((du,u,p,t)->$(toexpr(block)))
4845
elseif version == SArrayFunction
49-
dvar_exprs = [:($(Symbol("$(sys.dvs[i].name)_$(sys.ivs[1].name)"))) for i in 1:length(sys.dvs)]
46+
dvar_exprs = [:($(Symbol("$(sys.dvs[i].name)_$(sys.ivs[1].name)"))) for i in eachindex(sys.dvs)]
5047
svector_expr = quote
5148
E = eltype(tuple($(dvar_exprs...)))
5249
T = StaticArrays.similar_type(typeof(u), E)
@@ -84,8 +81,8 @@ function calculate_jacobian(sys::DiffEqSystem, simplify=true)
8481
end
8582

8683
function generate_ode_jacobian(sys::DiffEqSystem, simplify=true)
87-
var_exprs = [:($(sys.dvs[i].name) = u[$i]) for i in 1:length(sys.dvs)]
88-
param_exprs = [:($(sys.ps[i].name) = p[$i]) for i in 1:length(sys.ps)]
84+
var_exprs = [:($(sys.dvs[i].name) = u[$i]) for i in eachindex(sys.dvs)]
85+
param_exprs = [:($(sys.ps[i].name) = p[$i]) for i in eachindex(sys.ps)]
8986
diff_exprs = filter(!isintermediate, sys.eqs)
9087
jac = calculate_jacobian(sys, simplify)
9188
sys.jac = jac
@@ -96,8 +93,8 @@ function generate_ode_jacobian(sys::DiffEqSystem, simplify=true)
9693
end
9794

9895
function generate_ode_iW(sys::DiffEqSystem, simplify=true)
99-
var_exprs = [:($(sys.dvs[i].name) = u[$i]) for i in 1:length(sys.dvs)]
100-
param_exprs = [:($(sys.ps[i].name) = p[$i]) for i in 1:length(sys.ps)]
96+
var_exprs = [:($(sys.dvs[i].name) = u[$i]) for i in eachindex(sys.dvs)]
97+
param_exprs = [:($(sys.ps[i].name) = p[$i]) for i in eachindex(sys.ps)]
10198
diff_exprs = filter(!isintermediate, sys.eqs)
10299
jac = sys.jac
103100

src/systems/nonlinear/nonlinear_system.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,7 @@ end
1515
function NonlinearSystem(eqs;
1616
v_name = :DependentVariable,
1717
p_name = :Parameter)
18-
targetmap = Dict(v_name => v_name, p_name => p_name)
19-
vs, ps = extract_elements(eqs, targetmap)
18+
vs, ps = extract_elements(eqs, [_subtype(v_name), _subtype(p_name)])
2019
NonlinearSystem(eqs, vs, ps, [v_name], p_name)
2120
end
2221

test/system_construction.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ jac = ModelingToolkit.calculate_jacobian(de)
2020
f = ODEFunction(de)
2121
ModelingToolkit.generate_ode_iW(de)
2222

23-
# Differential equation with automatic extraction of variables on rhs
23+
# Differential equation with automatic extraction of variables
2424
de2 = DiffEqSystem(eqs, [t])
2525

2626
function test_vars_extraction(de, de2)

0 commit comments

Comments
 (0)