Skip to content

Commit b9cc424

Browse files
Break redundancy in simplification into helper functions
1 parent 7e5da57 commit b9cc424

File tree

3 files changed

+22
-10
lines changed

3 files changed

+22
-10
lines changed

src/simplify.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,15 @@ function _simplify_constants(O,shorten_tree = true)
2525
args = Vector{Expression}[O.args[i].args for i in idxs]
2626
push!(args,O.args[keep_idxs])
2727
return Operation(O.op,vcat(args...))
28+
end
2829
# Collapse constants
29-
elseif length(findall(x->typeof(x)<:Variable && x.subtype == :Constant ,O.args)) > 1
30-
idxs = findall(x->typeof(x)<:Variable && x.subtype == :Constant ,O.args)
30+
idxs = findall(is_constant, O.args)
31+
if length(idxs) > 1
3132
other_idxs = 1:length(O.args) .∉ (idxs,)
3233
if cur_op == :*
33-
new_var = Constant(prod(x->x.value,O.args[idxs]))
34+
new_var = Constant(prod(get, O.args[idxs]))
3435
elseif cur_op == :+
35-
new_var = Constant(sum(x->x.value,O.args[idxs]))
36+
new_var = Constant(sum(get, O.args[idxs]))
3637
end
3738
new_args = O.args[other_idxs]
3839
push!(new_args,new_var)
@@ -50,10 +51,10 @@ function _simplify_constants(O,shorten_tree = true)
5051
# If any variable is `Constant(0)`, zero the whole thing
5152
# If any variable is `Constant(1)`, remove that `Constant(1)` unless
5253
# they are all `Constant(1)`, in which case simplify to a single variable
53-
if any(x->typeof(x)<:Variable && (isequal(x,Constant(0)) || isequal(x,Constant(-0))),O.args)
54+
if any(iszero, O.args)
5455
return Constant(0)
55-
elseif any(x->typeof(x)<:Variable && isequal(x,Constant(1)),O.args)
56-
idxs = findall(x->typeof(x)<:Variable && isequal(x,Constant(1)),O.args)
56+
elseif any(isone, O.args)
57+
idxs = findall(isone, O.args)
5758
_O = Operation(O.op,O.args[1:length(O.args) .∉ (idxs,)])
5859
if isempty(_O.args)
5960
return Constant(1)
@@ -65,10 +66,9 @@ function _simplify_constants(O,shorten_tree = true)
6566
else
6667
return O
6768
end
68-
elseif Symbol(O.op) == :+ && any(x->typeof(x)<:Variable &&
69-
(isequal(x,Constant(0)) || isequal(x,Constant(-0))),O.args)
69+
elseif Symbol(O.op) == :+ && any(iszero, O.args)
7070
# If there are Constant(0)s in a big `+` expression, get rid of them
71-
idxs = findall(x->typeof(x)<:Variable && (isequal(x,Constant(0)) || isequal(x,Constant(-0))),O.args)
71+
idxs = findall(iszero, O.args)
7272
_O = Operation(O.op,O.args[1:length(O.args) .∉ (idxs,)])
7373
if isempty(_O.args)
7474
return Constant(0)

src/utils.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,6 @@ function flatten_expr!(x)
1818
end
1919

2020
toexpr(ex) = MacroTools.postwalk(x->x isa Union{Expression,Operation} ? Expr(x) : x, ex)
21+
22+
is_constant(x::Variable) = x.subtype === :Constant
23+
is_constant(::Any) = false

src/variables.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,15 @@ end
6363
export Variable,Parameter,Constant,DependentVariable,IndependentVariable,JumpVariable,NoiseVariable,
6464
@Var, @DVar, @IVar, @Param, @Const
6565

66+
67+
Base.get(x::Variable) = x.value
68+
69+
Base.iszero(::Expression) = false
70+
Base.iszero(c::Variable) = get(c) isa Number && iszero(get(c))
71+
Base.isone(::Expression) = false
72+
Base.isone(c::Variable) = get(c) isa Number && isone(get(c))
73+
74+
6675
# Variables use isequal for equality since == is an Operation
6776
function Base.:(==)(x::Variable,y::Variable)
6877
x.name == y.name && x.subtype == y.subtype && x.value == y.value &&

0 commit comments

Comments
 (0)