Skip to content

Commit 829e2e4

Browse files
Merge pull request #75 from JuliaDiffEq/hg/fix/simplify
Refactor simplification
2 parents fc4dffc + f4be6d8 commit 829e2e4

File tree

3 files changed

+69
-73
lines changed

3 files changed

+69
-73
lines changed

src/simplify.jl

Lines changed: 54 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -1,92 +1,73 @@
1-
function simplify_constants(O::Operation,shorten_tree = true)
2-
O_last = nothing
3-
_O = O
4-
while _O != O_last
5-
O_last = _O
6-
_O = _simplify_constants(_O,shorten_tree)
7-
if typeof(_O) <: Operation
8-
_O = Operation(_O.op,simplify_constants.(_O.args,shorten_tree))
1+
function simplify_constants(O::Operation, shorten_tree = true)
2+
while true
3+
O′ = _simplify_constants(O, shorten_tree)
4+
if is_operation(O′)
5+
O′ = Operation(O′.op, simplify_constants.(O′.args, shorten_tree))
96
end
7+
O == O′ && return O
8+
O = O′
109
end
11-
_O
1210
end
1311

14-
const TREE_SHRINK_OPS = [:*,:+]
12+
const AC_OPERATORS = (*, +)
1513

16-
function _simplify_constants(O,shorten_tree = true)
14+
function _simplify_constants(O, shorten_tree = true)
1715
# Tree shrinking
18-
if shorten_tree
19-
for cur_op in TREE_SHRINK_OPS
20-
if Symbol(O.op) == cur_op
21-
# Shrink tree
22-
if any(x->typeof(x)<:Operation && Symbol(x.op) == cur_op ,O.args)
23-
idxs = findall(x->typeof(x)<:Operation && Symbol(x.op) == cur_op,O.args)
24-
keep_idxs = 1:length(O.args) .∉ (idxs,)
25-
args = Vector{Expression}[O.args[i].args for i in idxs]
26-
push!(args,O.args[keep_idxs])
27-
return Operation(O.op,vcat(args...))
28-
# Collapse constants
29-
elseif length(findall(x->typeof(x)<:Variable && x.subtype == :Constant ,O.args)) > 1
30-
idxs = findall(x->typeof(x)<:Variable && x.subtype == :Constant ,O.args)
31-
other_idxs = 1:length(O.args) .∉ (idxs,)
32-
if cur_op == :*
33-
new_var = Constant(prod(x->x.value,O.args[idxs]))
34-
elseif cur_op == :+
35-
new_var = Constant(sum(x->x.value,O.args[idxs]))
36-
end
37-
new_args = O.args[other_idxs]
38-
push!(new_args,new_var)
39-
if length(new_args) > 1
40-
return Operation(O.op,new_args)
41-
else
42-
return new_args[1]
43-
end
44-
end
16+
if shorten_tree && O.op AC_OPERATORS
17+
# Flatten tree
18+
idxs = findall(x -> is_operation(x) && x.op === O.op, O.args)
19+
if !isempty(idxs)
20+
keep_idxs = eachindex(O.args) .∉ (idxs,)
21+
args = Vector{Expression}[O.args[i].args for i in idxs]
22+
push!(args, O.args[keep_idxs])
23+
return Operation(O.op, vcat(args...))
24+
end
25+
26+
# Collapse constants
27+
idxs = findall(is_constant, O.args)
28+
if length(idxs) > 1
29+
other_idxs = eachindex(O.args) .∉ (idxs,)
30+
new_const = Constant(mapreduce(get, O.op, O.args[idxs]))
31+
args = push!(O.args[other_idxs], new_const)
32+
33+
length(args) == 1 && return first(args)
34+
return Operation(O.op, args)
4535
end
46-
end
4736
end
4837

49-
if Symbol(O.op) == :*
38+
if O.op === (*)
5039
# If any variable is `Constant(0)`, zero the whole thing
40+
any(iszero, O.args) && return Constant(0)
41+
5142
# If any variable is `Constant(1)`, remove that `Constant(1)` unless
5243
# they are all `Constant(1)`, in which case simplify to a single variable
53-
if any(x->typeof(x)<:Variable && (isequal(x,Constant(0)) || isequal(x,Constant(-0))),O.args)
54-
return Constant(0)
55-
elseif any(x->typeof(x)<:Variable && isequal(x,Constant(1)),O.args)
56-
idxs = findall(x->typeof(x)<:Variable && isequal(x,Constant(1)),O.args)
57-
_O = Operation(O.op,O.args[1:length(O.args) .∉ (idxs,)])
58-
if isempty(_O.args)
59-
return Constant(1)
60-
elseif length(_O.args) == 1
61-
return _O.args[1]
62-
else
63-
return _O
64-
end
65-
else
66-
return O
67-
end
68-
elseif Symbol(O.op) == :+ && any(x->typeof(x)<:Variable &&
69-
(isequal(x,Constant(0)) || isequal(x,Constant(-0))),O.args)
70-
# If there are Constant(0)s in a big `+` expression, get rid of them
71-
idxs = findall(x->typeof(x)<:Variable && (isequal(x,Constant(0)) || isequal(x,Constant(-0))),O.args)
72-
_O = Operation(O.op,O.args[1:length(O.args) .∉ (idxs,)])
73-
if isempty(_O.args)
74-
return Constant(0)
75-
elseif length(_O.args) == 1
76-
return _O.args[1]
77-
else
78-
return O
44+
if any(isone, O.args)
45+
args = filter(!isone, O.args)
46+
47+
isempty(args) && return Constant(1)
48+
length(args) == 1 && return first(args)
49+
return Operation(O.op, args)
7950
end
80-
elseif O.op == identity
81-
return O.args[1]
82-
elseif Symbol(O.op) == :- && length(O.args) == 1
83-
return Operation(*,Expression[-1,O.args[1]])
84-
else
51+
8552
return O
8653
end
54+
55+
if O.op === (+) && any(iszero, O.args)
56+
# If there are Constant(0)s in a big `+` expression, get rid of them
57+
args = filter(!iszero, O.args)
58+
59+
isempty(args) && return Constant(0)
60+
length(args) == 1 && return first(args)
61+
return Operation(O.op, args)
62+
end
63+
64+
(O.op, length(O.args)) === (identity, 1) && return O.args[1]
65+
66+
(O.op, length(O.args)) === (-, 1) && return Operation(*, Expression[-1, O.args[1]])
67+
8768
return O
8869
end
89-
simplify_constants(x::Variable,y=false) = x
90-
_simplify_constants(x::Variable,y=false) = x
70+
simplify_constants(x::Variable, y=false) = x
71+
_simplify_constants(x::Variable, y=false) = x
9172

9273
export simplify_constants

src/utils.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,12 @@ end
1919

2020
toexpr(ex) = MacroTools.postwalk(x->x isa Union{Expression,Operation} ? Expr(x) : x, ex)
2121

22+
is_constant(x::Variable) = x.subtype === :Constant
23+
is_constant(::Any) = false
24+
25+
is_operation(::Operation) = true
26+
is_operation(::Any) = false
27+
2228
has_dependent(t::Variable) = Base.Fix2(has_dependent, t)
2329
has_dependent(x::Variable, t::Variable) =
2430
t x.dependents || any(has_dependent(t), x.dependents)

src/variables.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,15 @@ end
6363
export Variable,Parameter,Constant,DependentVariable,IndependentVariable,JumpVariable,NoiseVariable,
6464
@Var, @DVar, @IVar, @Param, @Const
6565

66+
67+
Base.get(x::Variable) = x.value
68+
69+
Base.iszero(::Expression) = false
70+
Base.iszero(c::Variable) = get(c) isa Number && iszero(get(c))
71+
Base.isone(::Expression) = false
72+
Base.isone(c::Variable) = get(c) isa Number && isone(get(c))
73+
74+
6675
# Variables use isequal for equality since == is an Operation
6776
function Base.:(==)(x::Variable,y::Variable)
6877
x.name == y.name && x.subtype == y.subtype && x.value == y.value &&

0 commit comments

Comments
 (0)