Skip to content

Commit bd482b1

Browse files
Use isequal instead of ==
1 parent 02edf04 commit bd482b1

13 files changed

+78
-70
lines changed

src/differentials.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ function (D::Differential)(x::Variable)
1212
return Operation(D, Expression[x])
1313
end
1414
(::Differential)(::Any) = Constant(0)
15-
Base.:(==)(D1::Differential, D2::Differential) = D1.x == D2.x
15+
Base.:(==)(D1::Differential, D2::Differential) = isequal(D1.x, D2.x)
1616

1717
function expand_derivatives(O::Operation)
1818
@. O.args = expand_derivatives(O.args)

src/equations.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,15 @@ struct Equation
55
lhs::Expression
66
rhs::Expression
77
end
8-
Base.:(==)(a::Equation, b::Equation) = (a.lhs, a.rhs) == (b.lhs, b.rhs)
8+
Base.:(==)(a::Equation, b::Equation) = isequal((a.lhs, a.rhs), (b.lhs, b.rhs))
99

1010
Base.:~(lhs::Expression, rhs::Expression) = Equation(lhs, rhs)
1111
Base.:~(lhs::Expression, rhs::Number ) = Equation(lhs, rhs)
1212
Base.:~(lhs::Number , rhs::Expression) = Equation(lhs, rhs)
1313

1414

1515
_is_dependent(x::Variable) = !x.known && !isempty(x.dependents)
16-
_is_parameter(iv) = x -> x.known && x iv
16+
_is_parameter(iv) = x -> x.known && !isequal(x, iv)
1717
_is_known(x::Variable) = x.known
1818
_is_unknown(x::Variable) = !x.known
1919

src/function_registration.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,13 @@ for (M, f, arity) in DiffRules.diffrules()
3232
@eval @register $sig
3333
end
3434

35-
for fun = (:<, :>, :(==), :!, :&, :|, :div)
35+
for fun [:!]
36+
basefun = Expr(:., Base, QuoteNode(fun))
37+
sig = :($basefun(x))
38+
@eval @register $sig
39+
end
40+
41+
for fun [:<, :>, :(==), :&, :|, :div]
3642
basefun = Expr(:., Base, QuoteNode(fun))
3743
sig = :($basefun(x,y))
3844
@eval @register $sig

src/operations.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,15 @@ struct Operation <: Expression
44
end
55

66
# Recursive ==
7-
function Base.:(==)(x::Operation,y::Operation)
7+
function Base.isequal(x::Operation,y::Operation)
88
x.op == y.op && length(x.args) == length(y.args) && all(isequal.(x.args,y.args))
99
end
10-
Base.:(==)(::Operation, ::Number ) = false
11-
Base.:(==)(::Number , ::Operation) = false
12-
Base.:(==)(::Operation, ::Variable ) = false
13-
Base.:(==)(::Variable , ::Operation) = false
14-
Base.:(==)(::Operation, ::Constant ) = false
15-
Base.:(==)(::Constant , ::Operation) = false
10+
Base.isequal(::Operation, ::Number ) = false
11+
Base.isequal(::Number , ::Operation) = false
12+
Base.isequal(::Operation, ::Variable ) = false
13+
Base.isequal(::Variable , ::Operation) = false
14+
Base.isequal(::Operation, ::Constant ) = false
15+
Base.isequal(::Constant , ::Operation) = false
1616

1717
Base.convert(::Type{Expr}, O::Operation) =
1818
build_expr(:call, Any[Symbol(O.op); convert.(Expr, O.args)])

src/simplify.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ function simplify_constants(O::Operation, shorten_tree)
44
if is_operation(O′)
55
O′ = Operation(O′.op, simplify_constants.(O′.args, shorten_tree))
66
end
7-
O == O′ && return O
7+
isequal(O, O′) && return O
88
O = O′
99
end
1010
end

src/systems/diffeqs/diffeqsystem.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ function flatten_differential(O::Operation)
1010
@assert is_derivative(O) "invalid differential: $O"
1111
is_derivative(O.args[1]) || return (O.args[1], O.op.x, 1)
1212
(x, t, order) = flatten_differential(O.args[1])
13-
t == O.op.x || throw(ArgumentError("non-matching differentials on lhs: $t, $(O.op.x)"))
13+
isequal(t, O.op.x) || throw(ArgumentError("non-matching differentials on lhs: $t, $(O.op.x)"))
1414
return (x, t, order + 1)
1515
end
1616

@@ -26,7 +26,7 @@ function Base.convert(::Type{DiffEq}, eq::Equation)
2626
(x, t, n) = flatten_differential(eq.lhs)
2727
return DiffEq(x, t, n, eq.rhs)
2828
end
29-
Base.:(==)(a::DiffEq, b::DiffEq) = (a.x, a.t, a.n, a.rhs) == (b.x, b.t, b.n, b.rhs)
29+
Base.:(==)(a::DiffEq, b::DiffEq) = isequal((a.x, a.t, a.n, a.rhs), (b.x, b.t, b.n, b.rhs))
3030
get_args(eq::DiffEq) = Expression[eq.x, eq.t, eq.rhs]
3131

3232
struct DiffEqSystem <: AbstractSystem

src/systems/diffeqs/first_order_transform.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ function ode_order_lowering(eqs, iv)
1717
var, maxorder = eq.x, eq.n
1818
if maxorder > get(var_order, var, 0)
1919
var_order[var] = maxorder
20-
var vars || push!(vars, var)
20+
any(isequal(var), vars) || push!(vars, var)
2121
end
2222
var′ = lower_varname(eq.x, eq.t, eq.n - 1)
2323
rhs′ = rename(eq.rhs)

src/utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,4 +65,4 @@ is_derivative(::Any) = false
6565

6666
has_dependent(t::Variable) = Base.Fix2(has_dependent, t)
6767
has_dependent(x::Variable, t::Variable) =
68-
t x.dependents || any(has_dependent(t), x.dependents)
68+
any(isequal(t), x.dependents) || any(has_dependent(t), x.dependents)

src/variables.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,14 @@ Base.isone(ex::Expression) = isa(ex, Constant) && isone(ex.value)
2222

2323

2424
# Variables use isequal for equality since == is an Operation
25-
Base.:(==)(x::Variable, y::Variable) = (x.name, x.known) == (y.name, y.known)
26-
Base.:(==)(::Variable, ::Number) = false
27-
Base.:(==)(::Number, ::Variable) = false
28-
Base.:(==)(::Variable, ::Constant) = false
29-
Base.:(==)(::Constant, ::Variable) = false
30-
Base.:(==)(c::Constant, n::Number) = c.value == n
31-
Base.:(==)(n::Number, c::Constant) = c.value == n
32-
Base.:(==)(a::Constant, b::Constant) = a.value == b.value
25+
Base.isequal(x::Variable, y::Variable) = (x.name, x.known) == (y.name, y.known)
26+
Base.isequal(::Variable, ::Number) = false
27+
Base.isequal(::Number, ::Variable) = false
28+
Base.isequal(::Variable, ::Constant) = false
29+
Base.isequal(::Constant, ::Variable) = false
30+
Base.isequal(c::Constant, n::Number) = c.value == n
31+
Base.isequal(n::Number, c::Constant) = c.value == n
32+
Base.isequal(a::Constant, b::Constant) = a.value == b.value
3333

3434
function Base.convert(::Type{Expr}, x::Variable)
3535
x.known || return x.name

test/derivatives.jl

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -6,52 +6,52 @@ using Test
66
@Unknown x(t) y(t) z(t)
77
@Deriv D'~t D2''~t
88

9-
@test expand_derivatives(D(t)) == 1
10-
@test expand_derivatives(D(D(t))) == 0
9+
@test isequal(expand_derivatives(D(t)), 1)
10+
@test isequal(expand_derivatives(D(D(t))), 0)
1111

1212
dsin = D(sin(t))
13-
@test expand_derivatives(dsin) == cos(t)
13+
@test isequal(expand_derivatives(dsin), cos(t))
1414

1515
dcsch = D(csch(t))
16-
@test expand_derivatives(dcsch) == simplify_constants(coth(t) * csch(t) * -1)
16+
@test isequal(expand_derivatives(dcsch), simplify_constants(coth(t) * csch(t) * -1))
1717

18-
@test expand_derivatives(D(-7)) == 0
19-
@test expand_derivatives(D(sin(2t))) == simplify_constants(cos(2t) * 2)
20-
@test expand_derivatives(D2(sin(t))) == simplify_constants(-sin(t))
21-
@test expand_derivatives(D2(sin(2t))) == simplify_constants(sin(2t) * -4)
22-
@test expand_derivatives(D2(t)) == 0
23-
@test expand_derivatives(D2(5)) == 0
18+
@test isequal(expand_derivatives(D(-7)), 0)
19+
@test isequal(expand_derivatives(D(sin(2t))), simplify_constants(cos(2t) * 2))
20+
@test isequal(expand_derivatives(D2(sin(t))), simplify_constants(-sin(t)))
21+
@test isequal(expand_derivatives(D2(sin(2t))), simplify_constants(sin(2t) * -4))
22+
@test isequal(expand_derivatives(D2(t)), 0)
23+
@test isequal(expand_derivatives(D2(5)), 0)
2424

2525
# Chain rule
2626
dsinsin = D(sin(sin(t)))
27-
@test expand_derivatives(dsinsin) == cos(sin(t))*cos(t)
27+
@test isequal(expand_derivatives(dsinsin), cos(sin(t))*cos(t))
2828

2929
d1 = D(sin(t)*t)
3030
d2 = D(sin(t)*cos(t))
31-
@test expand_derivatives(d1) == t*cos(t)+sin(t)
32-
@test expand_derivatives(d2) == simplify_constants(cos(t)*cos(t)+sin(t)*(-1*sin(t)))
31+
@test isequal(expand_derivatives(d1), t*cos(t)+sin(t))
32+
@test isequal(expand_derivatives(d2), simplify_constants(cos(t)*cos(t)+sin(t)*(-1*sin(t))))
3333

3434
eqs = [0 ~ σ*(y-x),
3535
0 ~ x*-z)-y,
3636
0 ~ x*y - β*z]
3737
sys = NonlinearSystem(eqs,[x,y,z],[σ,ρ,β])
3838
jac = calculate_jacobian(sys)
39-
@test jac[1,1] == σ*-1
40-
@test jac[1,2] == σ
41-
@test jac[1,3] == 0
42-
@test jac[2,1] == ρ-z
43-
@test jac[2,2] == -1
44-
@test jac[2,3] == x*-1
45-
@test jac[3,1] == y
46-
@test jac[3,2] == x
47-
@test jac[3,3] == -1*β
39+
@test isequal(jac[1,1], σ*-1)
40+
@test isequal(jac[1,2], σ)
41+
@test isequal(jac[1,3], 0)
42+
@test isequal(jac[2,1], ρ-z)
43+
@test isequal(jac[2,2], -1)
44+
@test isequal(jac[2,3], x*-1)
45+
@test isequal(jac[3,1], y)
46+
@test isequal(jac[3,2], x)
47+
@test isequal(jac[3,3], -1*β)
4848

4949
# Variable dependence checking in differentiation
5050
@Unknown a(t) b(a)
51-
@test D(b) 0
51+
@test !isequal(D(b), 0)
5252

53-
@test expand_derivatives(D(x * y)) == simplify_constants(y*D(x) + x*D(y))
54-
@test_broken expand_derivatives(D(x * y)) == simplify_constants(D(x)*y + x*D(y))
53+
@test isequal(expand_derivatives(D(x * y)), simplify_constants(y*D(x) + x*D(y)))
54+
@test_broken isequal(expand_derivatives(D(x * y)), simplify_constants(D(x)*y + x*D(y)))
5555

56-
@test expand_derivatives(D(2t)) == 2
57-
@test expand_derivatives(D(2x)) == 2D(x)
56+
@test isequal(expand_derivatives(D(2t)), 2)
57+
@test isequal(expand_derivatives(D(2x)), 2D(x))

0 commit comments

Comments
 (0)