Skip to content

Commit 8bd3c34

Browse files
Move Constant to separate datatype
1 parent 6061f03 commit 8bd3c34

File tree

6 files changed

+42
-56
lines changed

6 files changed

+42
-56
lines changed

src/ModelingToolkit.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ include("domains.jl")
1616
include("variables.jl")
1717

1818
Base.promote_rule(::Type{T},::Type{T2}) where {T<:Number,T2<:Expression} = Expression
19-
Base.one(::Type{T}) where T<:Expression = Constant(1)
20-
Base.zero(::Type{T}) where T<:Expression = Constant(0)
19+
Base.zero(::Type{<:Expression}) = Constant(0)
20+
Base.one(::Type{<:Expression}) = Constant(1)
2121
Base.convert(::Type{Variable},x::Int64) = Constant(x)
2222

2323
function caclulate_jacobian end

src/differentials.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ function expand_derivatives(O::Operation)
3131

3232
return O
3333
end
34-
expand_derivatives(x::Variable) = x
34+
expand_derivatives(x) = x
3535

3636
# Don't specialize on the function here
3737
function Derivative(O::Operation,idx)

src/operations.jl

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,12 @@ end
88
function Base.:(==)(x::Operation,y::Operation)
99
x.op == y.op && length(x.args) == length(y.args) && all(isequal.(x.args,y.args))
1010
end
11-
Base.:(==)(x::Operation,y::Number) = false
12-
Base.:(==)(x::Number,y::Operation) = false
13-
Base.:(==)(x::Operation,y::Nothing) = false
14-
Base.:(==)(x::Nothing,y::Operation) = false
15-
Base.:(==)(x::Variable,y::Operation) = false
16-
Base.:(==)(x::Operation,y::Variable) = false
11+
Base.:(==)(::Operation, ::Number ) = false
12+
Base.:(==)(::Number , ::Operation) = false
13+
Base.:(==)(::Operation, ::Variable ) = false
14+
Base.:(==)(::Variable , ::Operation) = false
15+
Base.:(==)(::Operation, ::Constant ) = false
16+
Base.:(==)(::Constant , ::Operation) = false
1717

1818
Base.convert(::Type{Expr}, O::Operation) =
1919
build_expr(:call, Any[Symbol(O.op); convert.(Expr, O.args)])
@@ -36,8 +36,8 @@ function find_replace!(O::Operation,x::Variable,y::Expression)
3636
end
3737

3838
# For inv
39-
Base.convert(::Type{Operation},x::Int) = Operation(identity,Expression[Constant(x)])
40-
Base.convert(::Type{Operation},x::Bool) = Operation(identity,Expression[Constant(x)])
41-
Base.convert(::Type{Operation},x::Variable) = Operation(identity,Expression[x])
42-
Operation(x) = convert(Operation,x)
43-
Operation(x::Operation) = x
39+
Base.convert(::Type{Operation}, x::Int) = Operation(identity, Expression[Constant(x)])
40+
Base.convert(::Type{Operation}, x::Bool) = Operation(identity, Expression[Constant(x)])
41+
Base.convert(::Type{Operation}, x::Operation) = x
42+
Base.convert(::Type{Operation}, x::Expression) = Operation(identity, Expression[x])
43+
Operation(x) = convert(Operation, x)

src/simplify.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
function simplify_constants(O::Operation, shorten_tree = true)
1+
function simplify_constants(O::Operation, shorten_tree)
22
while true
33
O′ = _simplify_constants(O, shorten_tree)
44
if is_operation(O′)
@@ -8,10 +8,13 @@ function simplify_constants(O::Operation, shorten_tree = true)
88
O = O′
99
end
1010
end
11+
simplify_constants(x, shorten_tree) = x
12+
simplify_constants(x) = simplify_constants(x, true)
13+
1114

1215
const AC_OPERATORS = (*, +)
1316

14-
function _simplify_constants(O, shorten_tree = true)
17+
function _simplify_constants(O::Operation, shorten_tree)
1518
# Tree shrinking
1619
if shorten_tree && O.op AC_OPERATORS
1720
# Flatten tree
@@ -67,7 +70,7 @@ function _simplify_constants(O, shorten_tree = true)
6770

6871
return O
6972
end
70-
simplify_constants(x::Variable, y=false) = x
71-
_simplify_constants(x::Variable, y=false) = x
73+
_simplify_constants(x, shorten_tree) = x
74+
_simplify_constants(x) = _simplify_constants(x, true)
7275

7376
export simplify_constants

src/utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ end
3333

3434
toexpr(ex) = MacroTools.postwalk(x -> isa(x, Expression) ? convert(Expr, x) : x, ex)
3535

36-
is_constant(x::Variable) = x.subtype === :Constant
36+
is_constant(::Constant) = true
3737
is_constant(::Any) = false
3838

3939
is_operation(::Operation) = true

src/variables.jl

Lines changed: 20 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,6 @@ Variable(name,x::Variable) = Variable(name,x.value,x.value_type,
3131
x.size,x.context)
3232

3333
Parameter(name,args...;kwargs...) = Variable(name,args...;subtype=:Parameter,kwargs...)
34-
Constant(value::Number) = Variable(Symbol(value),value,typeof(value);subtype=:Constant)
35-
Constant(name,args...;kwargs...) = Variable(name,args...;subtype=:Constant,kwargs...)
3634
IndependentVariable(name,args...;kwargs...) = Variable(name,args...;subtype=:IndependentVariable,kwargs...)
3735

3836
function DependentVariable(name,args...;dependents = [],kwargs...)
@@ -64,53 +62,38 @@ export Variable,Parameter,Constant,DependentVariable,IndependentVariable,JumpVar
6462
@Var, @DVar, @IVar, @Param, @Const
6563

6664

67-
Base.get(x::Variable) = x.value
65+
struct Constant <: Expression
66+
value::Number
67+
end
68+
Base.get(c::Constant) = c.value
69+
6870

69-
Base.iszero(::Expression) = false
70-
Base.iszero(c::Variable) = get(c) isa Number && iszero(get(c))
71-
Base.isone(::Expression) = false
72-
Base.isone(c::Variable) = get(c) isa Number && isone(get(c))
71+
Base.iszero(ex::Expression) = isa(ex, Constant) && iszero(ex.value)
72+
Base.isone(ex::Expression) = isa(ex, Constant) && isone(ex.value)
7373

7474

7575
# Variables use isequal for equality since == is an Operation
7676
function Base.:(==)(x::Variable,y::Variable)
7777
x.name == y.name && x.subtype == y.subtype && x.value == y.value &&
7878
x.value_type == y.value_type && x.diff == y.diff
7979
end
80-
81-
function Base.:(==)(x::Variable,y::Number)
82-
x == Constant(y)
83-
end
84-
85-
function Base.:(==)(x::Number,y::Variable)
86-
Constant(x) == y
87-
end
80+
Base.:(==)(::Variable, ::Number) = false
81+
Base.:(==)(::Number, ::Variable) = false
82+
Base.:(==)(::Variable, ::Constant) = false
83+
Base.:(==)(::Constant, ::Variable) = false
84+
Base.:(==)(c::Constant, n::Number) = c.value == n
85+
Base.:(==)(n::Number, c::Constant) = c.value == n
86+
Base.:(==)(a::Constant, b::Constant) = a.value == b.value
8887

8988
function Base.convert(::Type{Expr}, x::Variable)
90-
if x.subtype == :Constant
91-
return x.value
92-
elseif x.diff == nothing
93-
return :($(x.name))
94-
else
95-
return :($(Symbol("$(x.name)_$(x.diff.x.name)")))
96-
end
89+
x.diff === nothing && return x.name
90+
return Symbol("$(x.name)_$(x.diff.x.name)")
9791
end
92+
Base.convert(::Type{Expr}, c::Constant) = c.value
9893

99-
function Base.show(io::IO, A::Variable)
100-
if A.subtype == :Constant
101-
print(io,"Constant($(A.value))")
102-
else
103-
str = "$(A.subtype)($(A.name))"
104-
if A.value != nothing
105-
str *= ", value = " * string(A.value)
106-
end
107-
108-
if A.diff != nothing
109-
str *= ", diff = " * string(A.diff)
110-
end
111-
112-
print(io,str)
113-
end
94+
function Base.show(io::IO, x::Variable)
95+
print(io, x.subtype, '(', x.name, ')')
96+
x.diff === nothing || print(io, ", diff = ", x.diff)
11497
end
11598

11699
extract_idv(eq) = eq.args[1].diff.x

0 commit comments

Comments
 (0)