|
1 |
| -function simplify_constants(O::Operation,shorten_tree = true) |
2 |
| - O_last = nothing |
3 |
| - _O = O |
4 |
| - while _O != O_last |
5 |
| - O_last = _O |
6 |
| - _O = _simplify_constants(_O,shorten_tree) |
7 |
| - if typeof(_O) <: Operation |
8 |
| - _O = Operation(_O.op,simplify_constants.(_O.args,shorten_tree)) |
| 1 | +function simplify_constants(O::Operation, shorten_tree = true) |
| 2 | + while true |
| 3 | + O′ = _simplify_constants(O, shorten_tree) |
| 4 | + if is_operation(O′) |
| 5 | + O′ = Operation(O′.op, simplify_constants.(O′.args, shorten_tree)) |
9 | 6 | end
|
| 7 | + O == O′ && return O |
| 8 | + O = O′ |
10 | 9 | end
|
11 |
| - _O |
12 | 10 | end
|
13 | 11 |
|
14 |
| -const TREE_SHRINK_OPS = [:*,:+] |
| 12 | +const AC_OPERATORS = (*, +) |
15 | 13 |
|
16 |
| -function _simplify_constants(O,shorten_tree = true) |
| 14 | +function _simplify_constants(O, shorten_tree = true) |
17 | 15 | # Tree shrinking
|
18 |
| - if shorten_tree |
19 |
| - for cur_op in TREE_SHRINK_OPS |
20 |
| - if Symbol(O.op) == cur_op |
21 |
| - # Shrink tree |
22 |
| - if any(x->typeof(x)<:Operation && Symbol(x.op) == cur_op ,O.args) |
23 |
| - idxs = findall(x->typeof(x)<:Operation && Symbol(x.op) == cur_op,O.args) |
24 |
| - keep_idxs = 1:length(O.args) .∉ (idxs,) |
25 |
| - args = Vector{Expression}[O.args[i].args for i in idxs] |
26 |
| - push!(args,O.args[keep_idxs]) |
27 |
| - return Operation(O.op,vcat(args...)) |
28 |
| - # Collapse constants |
29 |
| - elseif length(findall(x->typeof(x)<:Variable && x.subtype == :Constant ,O.args)) > 1 |
30 |
| - idxs = findall(x->typeof(x)<:Variable && x.subtype == :Constant ,O.args) |
31 |
| - other_idxs = 1:length(O.args) .∉ (idxs,) |
32 |
| - if cur_op == :* |
33 |
| - new_var = Constant(prod(x->x.value,O.args[idxs])) |
34 |
| - elseif cur_op == :+ |
35 |
| - new_var = Constant(sum(x->x.value,O.args[idxs])) |
36 |
| - end |
37 |
| - new_args = O.args[other_idxs] |
38 |
| - push!(new_args,new_var) |
39 |
| - if length(new_args) > 1 |
40 |
| - return Operation(O.op,new_args) |
41 |
| - else |
42 |
| - return new_args[1] |
43 |
| - end |
44 |
| - end |
| 16 | + if shorten_tree && O.op ∈ AC_OPERATORS |
| 17 | + # Flatten tree |
| 18 | + idxs = findall(x -> is_operation(x) && x.op === O.op, O.args) |
| 19 | + if !isempty(idxs) |
| 20 | + keep_idxs = eachindex(O.args) .∉ (idxs,) |
| 21 | + args = Vector{Expression}[O.args[i].args for i in idxs] |
| 22 | + push!(args, O.args[keep_idxs]) |
| 23 | + return Operation(O.op, vcat(args...)) |
| 24 | + end |
| 25 | + |
| 26 | + # Collapse constants |
| 27 | + idxs = findall(is_constant, O.args) |
| 28 | + if length(idxs) > 1 |
| 29 | + other_idxs = eachindex(O.args) .∉ (idxs,) |
| 30 | + new_const = Constant(mapreduce(get, O.op, O.args[idxs])) |
| 31 | + args = push!(O.args[other_idxs], new_const) |
| 32 | + |
| 33 | + length(args) == 1 && return first(args) |
| 34 | + return Operation(O.op, args) |
45 | 35 | end
|
46 |
| - end |
47 | 36 | end
|
48 | 37 |
|
49 |
| - if Symbol(O.op) == :* |
| 38 | + if O.op === (*) |
50 | 39 | # If any variable is `Constant(0)`, zero the whole thing
|
| 40 | + any(iszero, O.args) && return Constant(0) |
| 41 | + |
51 | 42 | # If any variable is `Constant(1)`, remove that `Constant(1)` unless
|
52 | 43 | # they are all `Constant(1)`, in which case simplify to a single variable
|
53 |
| - if any(x->typeof(x)<:Variable && (isequal(x,Constant(0)) || isequal(x,Constant(-0))),O.args) |
54 |
| - return Constant(0) |
55 |
| - elseif any(x->typeof(x)<:Variable && isequal(x,Constant(1)),O.args) |
56 |
| - idxs = findall(x->typeof(x)<:Variable && isequal(x,Constant(1)),O.args) |
57 |
| - _O = Operation(O.op,O.args[1:length(O.args) .∉ (idxs,)]) |
58 |
| - if isempty(_O.args) |
59 |
| - return Constant(1) |
60 |
| - elseif length(_O.args) == 1 |
61 |
| - return _O.args[1] |
62 |
| - else |
63 |
| - return _O |
64 |
| - end |
65 |
| - else |
66 |
| - return O |
67 |
| - end |
68 |
| - elseif Symbol(O.op) == :+ && any(x->typeof(x)<:Variable && |
69 |
| - (isequal(x,Constant(0)) || isequal(x,Constant(-0))),O.args) |
70 |
| - # If there are Constant(0)s in a big `+` expression, get rid of them |
71 |
| - idxs = findall(x->typeof(x)<:Variable && (isequal(x,Constant(0)) || isequal(x,Constant(-0))),O.args) |
72 |
| - _O = Operation(O.op,O.args[1:length(O.args) .∉ (idxs,)]) |
73 |
| - if isempty(_O.args) |
74 |
| - return Constant(0) |
75 |
| - elseif length(_O.args) == 1 |
76 |
| - return _O.args[1] |
77 |
| - else |
78 |
| - return O |
| 44 | + if any(isone, O.args) |
| 45 | + args = filter(!isone, O.args) |
| 46 | + |
| 47 | + isempty(args) && return Constant(1) |
| 48 | + length(args) == 1 && return first(args) |
| 49 | + return Operation(O.op, args) |
79 | 50 | end
|
80 |
| - elseif O.op == identity |
81 |
| - return O.args[1] |
82 |
| - elseif Symbol(O.op) == :- && length(O.args) == 1 |
83 |
| - return Operation(*,Expression[-1,O.args[1]]) |
84 |
| - else |
| 51 | + |
85 | 52 | return O
|
86 | 53 | end
|
| 54 | + |
| 55 | + if O.op === (+) && any(iszero, O.args) |
| 56 | + # If there are Constant(0)s in a big `+` expression, get rid of them |
| 57 | + args = filter(!iszero, O.args) |
| 58 | + |
| 59 | + isempty(args) && return Constant(0) |
| 60 | + length(args) == 1 && return first(args) |
| 61 | + return Operation(O.op, args) |
| 62 | + end |
| 63 | + |
| 64 | + (O.op, length(O.args)) === (identity, 1) && return O.args[1] |
| 65 | + |
| 66 | + (O.op, length(O.args)) === (-, 1) && return Operation(*, Expression[-1, O.args[1]]) |
| 67 | + |
87 | 68 | return O
|
88 | 69 | end
|
89 |
| -simplify_constants(x::Variable,y=false) = x |
90 |
| -_simplify_constants(x::Variable,y=false) = x |
| 70 | +simplify_constants(x::Variable, y=false) = x |
| 71 | +_simplify_constants(x::Variable, y=false) = x |
91 | 72 |
|
92 | 73 | export simplify_constants
|
0 commit comments