Skip to content

Commit e5560a7

Browse files
Clarify control flow
1 parent fba101d commit e5560a7

File tree

1 file changed

+46
-63
lines changed

1 file changed

+46
-63
lines changed

src/simplify.jl

Lines changed: 46 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
function simplify_constants(O::Operation,shorten_tree = true)
1+
function simplify_constants(O::Operation, shorten_tree = true)
22
O_last = nothing
33
_O = O
44
while _O != O_last
@@ -11,79 +11,62 @@ function simplify_constants(O::Operation,shorten_tree = true)
1111
_O
1212
end
1313

14-
const TREE_SHRINK_OPS = [*, +]
14+
const AC_OPERATORS = [*, +]
1515

16-
function _simplify_constants(O,shorten_tree = true)
16+
function _simplify_constants(O, shorten_tree = true)
1717
# Tree shrinking
18-
if shorten_tree
19-
for cur_op in TREE_SHRINK_OPS
20-
if O.op == cur_op
21-
# Shrink tree
22-
if any(x -> is_operation(x) && x.op == cur_op, O.args)
23-
idxs = findall(x -> is_operation(x) && x.op == cur_op, O.args)
24-
keep_idxs = 1:length(O.args) .∉ (idxs,)
25-
args = Vector{Expression}[O.args[i].args for i in idxs]
26-
push!(args,O.args[keep_idxs])
27-
return Operation(O.op, vcat(args...))
28-
end
29-
# Collapse constants
30-
idxs = findall(is_constant, O.args)
31-
if length(idxs) > 1
32-
other_idxs = 1:length(O.args) .∉ (idxs,)
33-
if cur_op == (*)
34-
new_var = Constant(prod(get, O.args[idxs]))
35-
elseif cur_op == (+)
36-
new_var = Constant(sum(get, O.args[idxs]))
37-
end
38-
new_args = O.args[other_idxs]
39-
push!(new_args,new_var)
40-
if length(new_args) > 1
41-
return Operation(O.op,new_args)
42-
else
43-
return new_args[1]
44-
end
45-
end
46-
end
18+
if shorten_tree && O.op AC_OPERATORS
19+
# Flatten tree
20+
idxs = findall(x -> is_operation(x) && x.op == O.op, O.args)
21+
if !isempty(idxs)
22+
keep_idxs = eachindex(O.args) .∉ Ref(idxs)
23+
args = Vector{Expression}[O.args[i].args for i in idxs]
24+
push!(args, O.args[keep_idxs])
25+
return Operation(O.op, vcat(args...))
26+
end
27+
28+
# Collapse constants
29+
idxs = findall(is_constant, O.args)
30+
if length(idxs) > 1
31+
other_idxs = eachindex(O.args) .∉ (idxs,)
32+
new_var = Constant(mapreduce(get, O.op, O.args[idxs]))
33+
new_args = O.args[other_idxs]
34+
push!(new_args,new_var)
35+
36+
return length(new_args) > 1 ? Operation(O.op, new_args) : first(new_args)
4737
end
4838
end
4939

50-
if O.op == (*)
40+
if O.op === (*)
5141
# If any variable is `Constant(0)`, zero the whole thing
42+
any(iszero, O.args) && return Constant(0)
43+
5244
# If any variable is `Constant(1)`, remove that `Constant(1)` unless
5345
# they are all `Constant(1)`, in which case simplify to a single variable
54-
if any(iszero, O.args)
55-
return Constant(0)
56-
elseif any(isone, O.args)
57-
idxs = findall(isone, O.args)
58-
_O = Operation(O.op,O.args[1:length(O.args) .∉ (idxs,)])
59-
if isempty(_O.args)
60-
return Constant(1)
61-
elseif length(_O.args) == 1
62-
return _O.args[1]
63-
else
64-
return _O
65-
end
66-
else
67-
return O
68-
end
69-
elseif O.op == (+) && any(iszero, O.args)
70-
# If there are Constant(0)s in a big `+` expression, get rid of them
71-
idxs = findall(iszero, O.args)
72-
_O = Operation(O.op,O.args[1:length(O.args) .∉ (idxs,)])
73-
if isempty(_O.args)
74-
return Constant(0)
75-
elseif length(_O.args) == 1
76-
return _O.args[1]
77-
else
78-
return O
46+
if any(isone, O.args)
47+
args = filter(!isone, O.args)
48+
49+
isempty(args) && return Constant(1)
50+
length(args) == 1 && return first(args)
51+
return Operation(O.op, args)
7952
end
80-
elseif O.op == identity
81-
return O.args[1]
82-
elseif O.op == (-) && length(O.args) == 1
83-
return Operation(*,Expression[-1,O.args[1]])
84-
else
53+
8554
return O
8655
end
56+
57+
if O.op === (+) && any(iszero, O.args)
58+
# If there are Constant(0)s in a big `+` expression, get rid of them
59+
args = filter(!iszero, O.args)
60+
61+
isempty(args) && return Constant(0)
62+
length(args) == 1 && return first(args)
63+
return Operation(O.op, args)
64+
end
65+
66+
(O.op, length(O.args)) === (identity, 1) && return O.args[1]
67+
68+
(O.op, length(O.args)) === (-, 1) && return Operation(*, Expression[-1, O.args[1]])
69+
8770
return O
8871
end
8972
simplify_constants(x::Variable,y=false) = x

0 commit comments

Comments
 (0)