Skip to content

Commit 6c75a23

Browse files
Minor cleanup
1 parent c9642b4 commit 6c75a23

File tree

5 files changed

+14
-16
lines changed

5 files changed

+14
-16
lines changed

src/differentials.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ struct Differential <: Function
22
x::Expression
33
end
44

5+
Base.show(io::IO, D::Differential) = print(io, "(D'~", D.x, ")")
56
Base.convert(::Type{Expr}, D::Differential) = D
67

78
(D::Differential)(x::Operation) = Operation(D, Expression[x])

src/systems/diffeqs/diffeqsystem.jl

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,11 @@ using Base: RefValue
66

77
isintermediate(eq::Equation) = !(isa(eq.lhs, Operation) && isa(eq.lhs.op, Differential))
88

9-
function _unwrap_differenital(O)
10-
isa(O, Operation) || return (O, nothing, 0)
11-
isa(O.op, Differential) || return (O, nothing, 0)
12-
(x, t, order) = _unwrap_differenital(O.args[1])
13-
t === nothing && (t = O.op.x)
14-
t == O.op.x || throw(ArgumentError("non-matching differentials on lhs"))
9+
function flatten_differential(O::Operation)
10+
@assert is_derivative(O) "invalid differential: $O"
11+
is_derivative(O.args[1]) || return (O.args[1], O.op.x, 1)
12+
(x, t, order) = flatten_differential(O.args[1])
13+
t == O.op.x || throw(ArgumentError("non-matching differentials on lhs: $t, $(O.op.x)"))
1514
return (x, t, order + 1)
1615
end
1716

@@ -24,7 +23,7 @@ struct DiffEq # dⁿx/dtⁿ = rhs
2423
end
2524
function Base.convert(::Type{DiffEq}, eq::Equation)
2625
isintermediate(eq) && throw(ArgumentError("intermediate equation received"))
27-
(x, t, n) = _unwrap_differenital(eq.lhs)
26+
(x, t, n) = flatten_differential(eq.lhs)
2827
return DiffEq(x, t, n, eq.rhs)
2928
end
3029
Base.:(==)(a::DiffEq, b::DiffEq) = (a.x, a.t, a.n, a.rhs) == (b.x, b.t, b.n, b.rhs)

src/systems/diffeqs/first_order_transform.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
function lower_varname(var::Variable, idv, order = 0)
1+
function lower_varname(var::Variable, idv, order)
22
order == 0 && return var
33
name = Symbol(var.name, :_, string(idv.name)^order)
44
return Variable(name, var.known, var.dependents)
@@ -39,8 +39,8 @@ end
3939

4040
function rename(O::Expression)
4141
isa(O, Operation) || return O
42-
if isa(O.op, Differential)
43-
(x, t, order) = _unwrap_differenital(O)
42+
if is_derivative(O)
43+
(x, t, order) = flatten_differential(O)
4444
return lower_varname(x, t, order)
4545
end
4646
return Operation(O.op, rename.(O.args))

src/utils.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@ is_constant(::Any) = false
6060
is_operation(::Operation) = true
6161
is_operation(::Any) = false
6262

63+
is_derivative(O::Operation) = isa(O.op, Differential)
64+
is_derivative(::Any) = false
65+
6366
has_dependent(t::Variable) = Base.Fix2(has_dependent, t)
6467
has_dependent(x::Variable, t::Variable) =
6568
t x.dependents || any(has_dependent(t), x.dependents)

src/variables.jl

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,7 @@ function Base.convert(::Type{Expr}, x::Variable)
3838
end
3939
Base.convert(::Type{Expr}, c::Constant) = c.value
4040

41-
function Base.show(io::IO, x::Variable)
42-
subtype = x.known ? :Parameter : :Unknown
43-
print(io, subtype, '(', repr(x.name))
44-
isempty(x.dependents) || print(io, ", ", x.dependents)
45-
print(io, ')')
46-
end
41+
Base.show(io::IO, x::Variable) = print(io, x.name)
4742

4843
# Build variables more easily
4944
function _parse_vars(macroname, fun, x)

0 commit comments

Comments
 (0)