Skip to content

Commit 873400a

Browse files
committed
fix reduction
1 parent 5210eef commit 873400a

File tree

7 files changed

+53
-41
lines changed

7 files changed

+53
-41
lines changed

src/equations.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ end
1515
Base.:(==)(a::Equation, b::Equation) = all(isequal.((a.lhs, a.rhs), (b.lhs, b.rhs)))
1616
Base.hash(a::Equation, salt::UInt) = hash(a.lhs, hash(a.rhs, salt))
1717

18+
SymbolicUtils.simplify(x::Equation; kw...) = simplify(x.lhs; kw...) ~ simplify(x.rhs; kw...)
19+
1820
"""
1921
$(TYPEDSIGNATURES)
2022
@@ -41,7 +43,6 @@ Base.:~(lhs::Number , rhs::Num) = Equation(value(lhs), value(rhs))
4143
Base.:~(lhs::Symbolic, rhs::Symbolic) = Equation(value(lhs), value(rhs))
4244
Base.:~(lhs::Symbolic, rhs::Any ) = Equation(value(lhs), value(rhs))
4345
Base.:~(lhs::Any, rhs::Symbolic ) = Equation(value(lhs), value(rhs))
44-
Base.:~(lhs::Number , rhs::Num) = Equation(value(lhs), value(rhs))
4546

4647
struct ConstrainedEquation
4748
constraints

src/solve.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ end
7272
function _solve(A, b)
7373
A = SymbolicUtils.simplify.(to_symbolic.(A), polynorm=true)
7474
b = SymbolicUtils.simplify.(to_symbolic.(b), polynorm=true)
75-
map(to_mtk, SymbolicUtils.simplify.(ldiv(sym_lu(A), b)))
75+
SymbolicUtils.simplify.(ldiv(sym_lu(A), b))
7676
end
7777

7878
LinearAlgebra.:(\)(A::AbstractMatrix{<:Expression}, b::AbstractVector{<:Expression}) = _solve(A, b)

src/systems/abstractsystem.jl

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,9 @@ Generate a function to evaluate the system's equations.
117117
"""
118118
function generate_function end
119119

120+
getname(x::Sym) = nameof(x)
121+
getname(t::Term) = t.op isa Sym ? getname(t.op) : error("Cannot get name of $t")
122+
120123
function Base.getproperty(sys::AbstractSystem, name::Symbol)
121124

122125
if name fieldnames(typeof(sys))
@@ -128,27 +131,23 @@ function Base.getproperty(sys::AbstractSystem, name::Symbol)
128131
end
129132
end
130133

131-
i = findfirst(x->x.name==name,sys.states)
134+
i = findfirst(x->getname(x) == name, sys.states)
135+
132136
if i !== nothing
133-
x = rename(sys.states[i],renamespace(sys.name,name))
134-
if :iv fieldnames(typeof(sys))
135-
return x(getfield(sys,:iv)())
136-
else
137-
return x()
138-
end
137+
return rename(sys.states[i],renamespace(sys.name,name))
139138
end
140139

141140
if :ps fieldnames(typeof(sys))
142-
i = findfirst(x->x.name==name,sys.ps)
141+
i = findfirst(x->getname(x) == name,sys.ps)
143142
if i !== nothing
144-
return rename(sys.ps[i],renamespace(sys.name,name))()
143+
return rename(sys.ps[i],renamespace(sys.name,name))
145144
end
146145
end
147146

148147
if :observed fieldnames(typeof(sys))
149-
i = findfirst(x->convert(Variable,x.lhs).name==name,sys.observed)
148+
i = findfirst(x->getname(x.lhs)==name,sys.observed)
150149
if i !== nothing
151-
return rename(convert(Variable,sys.observed[i].lhs),renamespace(sys.name,name))(getfield(sys,:iv)())
150+
return rename(sys.observed[i].lhs,renamespace(sys.name,name))
152151
end
153152
end
154153

@@ -172,19 +171,23 @@ end
172171
namespace_equations(sys::AbstractSystem) = namespace_equation.(equations(sys),sys.name,sys.iv.name)
173172

174173
function namespace_equation(eq::Equation,name,ivname)
175-
_lhs = namespace_operation(eq.lhs,name,ivname)
176-
_rhs = namespace_operation(eq.rhs,name,ivname)
174+
_lhs = namespace_expr(eq.lhs,name,ivname)
175+
_rhs = namespace_expr(eq.rhs,name,ivname)
177176
_lhs ~ _rhs
178177
end
179178

180-
function namespace_operation(O::Operation,name,ivname)
181-
if O.op isa Sym && O.op.name != ivname
182-
Operation(rename(O.op,renamespace(name,O.op.name)),namespace_operation.(O.args,name,ivname))
179+
function namespace_expr(O::Sym,name,ivname)
180+
O.name == ivname ? O : rename(O,renamespace(name,O.name))
181+
end
182+
183+
function namespace_expr(O::Term,name,ivname)
184+
if O.op isa Sym
185+
Term(rename(O.op,renamespace(name,O.op.name)),namespace_expr.(O.args,name,ivname))
183186
else
184-
Operation(O.op,namespace_operation.(O.args,name,ivname))
187+
Term(O.op,namespace_expr.(O.args,name,ivname))
185188
end
186189
end
187-
namespace_operation(O::Constant,name,ivname) = O
190+
namespace_expr(O,name,ivname) = O
188191

189192
independent_variable(sys::AbstractSystem) = sys.iv
190193
states(sys::AbstractSystem) = unique(isempty(sys.systems) ? setdiff(sys.states, convert.(Variable,sys.pins)) : [sys.states;reduce(vcat,namespace_variables.(sys.systems))])

src/systems/diffeqs/odesystem.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ struct ODESystem <: AbstractODESystem
3131
states::Vector
3232
"""Parameter variables."""
3333
ps::Vector
34-
pins::Vector{Variable}
34+
pins::Vector{Num}
3535
observed::Vector{Equation}
3636
"""
3737
Time-derivative matrix. Note: this field will not be defined until
@@ -114,8 +114,6 @@ function ODESystem(eqs, iv=nothing; kwargs...)
114114
break
115115
end
116116
end
117-
else
118-
iv = convert(Variable, iv)
119117
end
120118
iv === nothing && throw(ArgumentError("Please pass in independent variables."))
121119
for eq in eqs

src/systems/reduction.jl

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -34,46 +34,55 @@ function make_lhs_0(eq)
3434
0 ~ eq.lhs - eq.rhs
3535
end
3636
end
37+
isvar(s::Sym) = !isparameter(s)
38+
isvar(s::Term) = isvar(s.op)
39+
isvar(s::Any) = false
40+
41+
function filterexpr(f, s)
42+
vs = []
43+
Rewriters.Prewalk(Rewriters.Chain([@rule((~x::f) => push!(vs, ~x))]))(s)
44+
vs
45+
end
3746

3847
function alias_elimination(sys::ODESystem)
3948
eqs = vcat(equations(sys), observed(sys))
4049

4150
# make all algebraic equations have 0 on LHS
4251
eqs = map(eqs) do eq
43-
if eq.lhs isa Operation && eq.lhs.op isa Differential
52+
if eq.lhs isa Term && eq.lhs.op isa Differential
4453
eq
4554
else
4655
make_lhs_0(eq)
4756
end
4857
end
4958

50-
new_stateops = map(eqs) do eq
51-
if eq.lhs isa Operation && eq.lhs.op isa Differential
52-
get_variables(eq.lhs)
59+
newstates = map(eqs) do eq
60+
if eq.lhs isa Term && eq.lhs.op isa Differential
61+
filterexpr(isvar, eq.lhs)
5362
else
5463
[]
5564
end
5665
end |> Iterators.flatten |> collect |> unique
5766

67+
5868
all_vars = map(eqs) do eq
59-
filter(x->!isparameter(x.op), get_variables(eq.rhs))
69+
@show eq.rhs
70+
@show filterexpr(isvar, eq.rhs)
71+
filterexpr(isvar, eq.rhs)
6072
end |> Iterators.flatten |> collect |> unique
6173

62-
newstates = convert.(Variable, new_stateops)
63-
64-
65-
alg_idxs = findall(x->x.lhs isa Constant && iszero(x.lhs), eqs)
66-
67-
eliminate = setdiff(convert.(Variable, all_vars), newstates)
74+
alg_idxs = findall(x->!(x.lhs isa Term) && iszero(x.lhs), eqs)
75+
@show all_vars, newstates
6876

69-
vars = map(x->x(sys.iv()), eliminate)
77+
eliminate = setdiff(all_vars, newstates)
78+
@show eliminate
7079

71-
outputs = solve_for(eqs[alg_idxs], vars)
80+
outputs = solve_for(eqs[alg_idxs], eliminate)
7281

7382
diffeqs = eqs[setdiff(1:length(eqs), alg_idxs)]
7483

75-
diffeqs′ = substitute_aliases(diffeqs, Dict(vars .=> outputs))
84+
diffeqs′ = substitute_aliases(diffeqs, Dict(eliminate .=> outputs))
7685

77-
ODESystem(diffeqs′, sys.iv(), new_stateops, parameters(sys), observed=vars .~ outputs)
86+
ODESystem(diffeqs′, sys.iv, newstates, parameters(sys), observed=eliminate .~ outputs)
7887
end
7988

src/variables.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ function Variable(name, indices...)
6363
end
6464

6565
rename(x::Sym{T},name) where T = Sym{T}(name)
66+
rename(x::Term, name) where T = x.op isa Sym ? rename(x.op, name)(x.args...) : error("can't rename $x to $name")
6667

6768
"""
6869
$(TYPEDEF)

test/reduction.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ flattened_system = ModelingToolkit.flatten(connected)
6161

6262
aliased_flattened_system = alias_elimination(flattened_system)
6363

64-
@test states(aliased_flattened_system) == convert.(Variable, [
64+
@test isequal(states(aliased_flattened_system), [
6565
lorenz1.x
6666
lorenz1.y
6767
lorenz1.z
@@ -70,15 +70,15 @@ aliased_flattened_system = alias_elimination(flattened_system)
7070
lorenz2.z
7171
])
7272

73-
@test setdiff(parameters(aliased_flattened_system), convert.(Variable, [
73+
@test setdiff(parameters(aliased_flattened_system), [
7474
lorenz1.σ
7575
lorenz1.ρ
7676
lorenz1.β
7777
lorenz1.F
7878
lorenz2.F
7979
lorenz2.ρ
8080
lorenz2.β
81-
])) |> isempty
81+
]) |> isempty
8282

8383
test_equal.(equations(aliased_flattened_system), [
8484
D(lorenz1.x) ~ lorenz1.σ*(lorenz1.y-lorenz1.x) + lorenz2.x - lorenz2.y - lorenz2.z,

0 commit comments

Comments
 (0)