Skip to content

Commit 49edc23

Browse files
committed
more changes
1 parent a11cfa6 commit 49edc23

File tree

4 files changed

+23
-18
lines changed

4 files changed

+23
-18
lines changed

src/equations.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ struct Equation
1212
"""The expression on the right-hand side of the equation."""
1313
rhs
1414
end
15-
Base.:(==)(a::Equation, b::Equation) = isequal((a.lhs, a.rhs), (b.lhs, b.rhs))
15+
Base.:(==)(a::Equation, b::Equation) = all(isequal.((a.lhs, a.rhs), (b.lhs, b.rhs)))
1616
Base.hash(a::Equation, salt::UInt) = hash(a.lhs, hash(a.rhs, salt))
1717

1818
"""
@@ -38,6 +38,10 @@ Equation(x() - y(), 0)
3838
Base.:~(lhs::Num, rhs::Num) = Equation(value(lhs), value(rhs))
3939
Base.:~(lhs::Num, rhs::Number ) = Equation(value(lhs), value(rhs))
4040
Base.:~(lhs::Number , rhs::Num) = Equation(value(lhs), value(rhs))
41+
Base.:~(lhs::Symbolic, rhs::Symbolic) = Equation(value(lhs), value(rhs))
42+
Base.:~(lhs::Symbolic, rhs::Any ) = Equation(value(lhs), value(rhs))
43+
Base.:~(lhs::Any, rhs::Symbolic ) = Equation(value(lhs), value(rhs))
44+
Base.:~(lhs::Number , rhs::Num) = Equation(value(lhs), value(rhs))
4145

4246
struct ConstrainedEquation
4347
constraints

src/systems/diffeqs/first_order_transform.jl

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1-
function lower_varname(var::Variable, idv, order)
1+
function lower_varname(var::Term, idv, order)
22
order == 0 && return var
3-
name = Symbol(var.name, , string(idv.name)^order)
4-
return Variable{vartype(var)}(name)
3+
name = Symbol(nameof(var.op), , string(idv)^order)
4+
#name = Symbol(var.name, :ˍ, string(idv.name)^order)
5+
return Sym{symtype(var.op)}(name)(var.args[1])
56
end
67

7-
function flatten_differential(O::Operation)
8+
function flatten_differential(O::Term)
89
@assert is_derivative(O) "invalid differential: $O"
910
is_derivative(O.args[1]) || return (O.args[1], O.op.x, 1)
1011
(x, t, order) = flatten_differential(O.args[1])
@@ -24,12 +25,12 @@ function ode_order_lowering(sys::ODESystem)
2425
end
2526

2627
function ode_order_lowering(eqs, iv, states)
27-
var_order = OrderedDict{Variable,Int}()
28-
D = Differential(iv())
28+
var_order = OrderedDict{Any,Int}()
29+
D = Differential(iv)
2930
diff_eqs = Equation[]
30-
diff_vars = Variable[]
31+
diff_vars = []
3132
alge_eqs = Equation[]
32-
alge_vars = Variable[]
33+
alge_vars = []
3334

3435
for (i, (eq, ss)) enumerate(zip(eqs, states))
3536
if isequal(eq.lhs, Constant(0))
@@ -42,7 +43,7 @@ function ode_order_lowering(eqs, iv, states)
4243
var′ = lower_varname(var, iv, maxorder - 1)
4344
rhs′ = rename_lower_order(eq.rhs)
4445
push!(diff_vars, var′)
45-
push!(diff_eqs, D(var′(iv())) ~ rhs′)
46+
push!(diff_eqs, D(var′) ~ rhs′)
4647
end
4748
end
4849

@@ -52,8 +53,8 @@ function ode_order_lowering(eqs, iv, states)
5253
rvar = lower_varname(var, iv, o)
5354
push!(diff_vars, lvar)
5455

55-
rhs = rvar(iv())
56-
eq = Differential(iv())(lvar(iv())) ~ rhs
56+
rhs = rvar
57+
eq = Differential(iv)(lvar) ~ rhs
5758
push!(diff_eqs, eq)
5859
end
5960
end
@@ -62,11 +63,11 @@ function ode_order_lowering(eqs, iv, states)
6263
return (vcat(diff_eqs, alge_eqs), vcat(diff_vars, alge_vars))
6364
end
6465

65-
function rename_lower_order(O::Expression)
66-
isa(O, Operation) || return O
66+
function rename_lower_order(O)
67+
isa(O, Term) || return O
6768
if is_derivative(O)
6869
(x, t, order) = flatten_differential(O)
69-
return lower_varname(x.op, t.op, order)(x.args...)
70+
return lower_varname(x, t, order)
7071
end
71-
return Operation(O.op, rename_lower_order.(O.args))
72+
return Term(O.op, rename_lower_order.(O.args))
7273
end

src/utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ is_constant(::Any) = false
5555
is_operation(::Operation) = true
5656
is_operation(::Any) = false
5757

58-
is_derivative(O::Operation) = isa(O.op, Differential)
58+
is_derivative(O::Term) = isa(O.op, Differential)
5959
is_derivative(::Any) = false
6060

6161
"""

test/odesystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ lowered_eqs = [D(uˍtt) ~ 2uˍtt + uˍt + xˍt + 1
119119
@test de1 == ODESystem(lowered_eqs)
120120

121121
# issue #219
122-
@test de1.states == [ModelingToolkit.var_from_nested_derivative(eq.lhs)[1] for eq in de1.eqs] == ODESystem(lowered_eqs).states
122+
@test all(isequal.([ModelingToolkit.var_from_nested_derivative(eq.lhs)[1] for eq in de1.eqs], ODESystem(lowered_eqs).states))
123123

124124
test_diffeq_inference("first-order transform", de1, t, [uˍtt, xˍt, uˍt, u, x], [])
125125
du = zeros(5)

0 commit comments

Comments
 (0)