Skip to content

Commit 701ce7f

Browse files
authored
Merge pull request #389 from JuliaSymbolics/ale/egraphs
[WIP] EGraphs optimization
2 parents 7078007 + 847c86f commit 701ce7f

File tree

7 files changed

+232
-24
lines changed

7 files changed

+232
-24
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ DocStringExtensions = "0.8"
3636
DynamicPolynomials = "0.3"
3737
IfElse = "0.1"
3838
LabelledArrays = "1.5"
39-
Metatheory = "1.1"
39+
Metatheory = "1.2"
4040
MultivariatePolynomials = "0.3"
4141
NaNMath = "0.3"
4242
Setfield = "0.7, 0.8"

src/SymbolicUtils.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,13 @@ include("simplify_rules.jl")
5757
export simplify, substitute
5858
include("api.jl")
5959

60+
# EGraph rewriting
61+
include("egraph.jl")
62+
export optimize
63+
6064
include("code.jl")
6165

66+
6267
# ADjoints
6368
include("adjoints.jl")
6469

src/egraph.jl

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
using Metatheory.Rewriters
2+
3+
function EGraphs.preprocess(t::Symbolic)
4+
toterm(unflatten(t))
5+
end
6+
7+
function symbolicegraph(ex)
8+
g = EGraph(ex)
9+
analyze!(g, SymbolicUtils.SymtypeAnalysis)
10+
settermtype!(g, Term{symtype(ex), Any})
11+
return g
12+
end
13+
14+
15+
"""
16+
Equational rewrite rules for optimizing expressions
17+
"""
18+
opt_theory = @theory a b x y begin
19+
a + b == b + a
20+
a * b == b * a
21+
a * x + a * y == a*(x+y)
22+
-1 * a == -a
23+
a + (-1 * b) == a - b
24+
x^-1 == 1/x
25+
1/x * a == a/x
26+
# fraction rules
27+
# (a/b) + (c/b) => (a+c)/b
28+
# trig functions
29+
sin(x)/cos(x) == tan(x)
30+
cos(x)/sin(x) == cot(x)
31+
sin(x)^2 + cos(x)^2 --> 1
32+
sin(2a) == 2sin(a)cos(a)
33+
end
34+
35+
36+
"""
37+
Approximation of costs of operators in number
38+
of CPU cycles required for the numerical computation
39+
40+
See
41+
* https://latkin.org/blog/2014/11/09/a-simple-benchmark-of-various-math-operations/
42+
* https://streamhpc.com/blog/2012-07-16/how-expensive-is-an-operation-on-a-cpu/
43+
* https://github.com/triscale-innov/GFlops.jl
44+
"""
45+
const op_costs = Dict(
46+
(+) => 1,
47+
(-) => 1,
48+
abs => 2,
49+
(*) => 3,
50+
exp => 18,
51+
(/) => 24,
52+
(^) => 100,
53+
log1p => 124,
54+
deg2rad => 125,
55+
rad2deg => 125,
56+
acos => 127,
57+
asind => 128,
58+
acsch => 133,
59+
sin => 134,
60+
cos => 134,
61+
atan => 135,
62+
tan => 156,
63+
)
64+
# TODO some operator costs are in FLOP and not in cycles!!
65+
66+
function costfun(n::ENodeTerm, g::EGraph, an)
67+
op = operation(n)
68+
cost = 0
69+
cost += get(op_costs, op, 1)
70+
71+
for id n.args
72+
eclass = g[id]
73+
!hasdata(eclass, an) && (cost += Inf; break)
74+
cost += last(getdata(eclass, an))
75+
end
76+
cost
77+
end
78+
79+
costfun(n::ENodeLiteral, g::EGraph, an) = 0
80+
81+
egraph_simterm(x, head, args, symtype=nothing; metadata=nothing, exprhead=exprhead(x)) =
82+
TermInterface.similarterm(typeof(x), head, args, symtype; metadata=metadata, exprhead=exprhead)
83+
84+
85+
# Custom similarterm to use in EGraphs on <:Symbolic types that treats everything as a Term
86+
function egraph_simterm(x::Type{<:Term}, f, args, symtype=nothing; metadata=nothing, exprhead=:call)
87+
T = symtype
88+
if T === nothing
89+
T = _promote_symtype(f, args)
90+
end
91+
res = Term{T}(f isa Symbol ? eval(f) : f, args; metadata=metadata);
92+
return res
93+
end
94+
95+
function optimize(ex; params=SaturationParams(timeout=20))
96+
# @show ex
97+
g = symbolicegraph(ex)
98+
params.simterm = egraph_simterm
99+
saturate!(g, opt_theory, params)
100+
return extract!(g, costfun)
101+
end

src/polyform.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,10 @@ function unpolyize(x)
236236
Postwalk(identity, similarterm=simterm)(x)
237237
end
238238

239+
function toterm(x::PolyForm)
240+
toterm(unpolyize(x))
241+
end
242+
239243
## Rational Polynomial form with Div
240244

241245
function polyform_factors(d, pvar2sym, sym2term)

src/types.jl

Lines changed: 100 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,25 @@ Base.isequal(::Symbolic, ::Symbolic) = false
6767

6868
### End of interface
6969

70+
### Metatheory.jl e-graph rewriting integration
71+
72+
"""
73+
SymtypeAnalysis
74+
75+
This abstract type is used to identify the EGraph analysis
76+
that keeps track of symtype through an EGraph. This must
77+
be added to every EGraph that is used in SymbolicUtils.
78+
"""
79+
abstract type SymtypeAnalysis <: AbstractAnalysis end
80+
_getsymtype(T::Type{<:Symbolic{X}}) where X = X
81+
_getsymtype(T::Type{X}) where {X} = X
82+
EGraphs.make(an::Type{SymtypeAnalysis}, g::EGraph, n::ENodeLiteral) = symtype(n.value)
83+
EGraphs.make(an::Type{SymtypeAnalysis}, g::EGraph, n::ENodeTerm{T}) where {T} = _getsymtype(T)
84+
EGraphs.join(an::Type{SymtypeAnalysis}, A, B) = Union{A, B}
85+
86+
# TODO JOIN egraph analysis
87+
TermInterface.symtype(ec::EClass) = getdata(ec, SymtypeAnalysis, Any)
88+
7089
function to_symbolic(x)
7190
Base.depwarn("`to_symbolic(x)` is deprecated, define the interface for your " *
7291
"symbolic structure using `istree(x)`, `operation(x)`, `arguments(x)` " *
@@ -348,6 +367,24 @@ function term(f, args...; type = nothing)
348367
Term{T}(f, [args...])
349368
end
350369

370+
"""
371+
unflatten(t::Symbolic{T})
372+
Binarizes `Term`s with n-ary operations
373+
"""
374+
function unflatten(t::Symbolic{T}) where{T}
375+
if istree(t)
376+
f = operation(t)
377+
if f == (+) || f == (*) # TODO check out for other n-ary --> binary ops
378+
a = arguments(t)
379+
return foldl((x,y) -> Term{T}(f, [x, y]), a)
380+
end
381+
end
382+
return t
383+
end
384+
385+
unflatten(t) = t
386+
387+
351388
"""
352389
similarterm(t, f, args, symtype; metadata=nothing)
353390
@@ -366,10 +403,17 @@ different type than `t`, because `f` also influences the result.
366403
"""
367404
TermInterface.similarterm(t::Type{<:Symbolic}, f, args; metadata=nothing, exprhead=:call) =
368405
similarterm(t, f, args, _promote_symtype(f, args); metadata=metadata, exprhead=exprhead)
369-
406+
407+
TermInterface.similarterm(t::Type{<:Symbolic}, f::Symbol, args; metadata=nothing, exprhead=:call) =
408+
TermInterface.similarterm(t, eval(f), args; metadata=metadata, exprhead=exprhead)
409+
370410
TermInterface.similarterm(t::Type{<:Term}, f, args, symtype; metadata=nothing, exprhead=:call) =
371411
Term{_promote_symtype(f, args)}(f, args; metadata=metadata)
372412

413+
TermInterface.similarterm(t::Type{<:Term}, f::Symbol, args, symtype; metadata=nothing, exprhead=:call) =
414+
Term{_promote_symtype(eval(f), args)}(eval(f), args; metadata=metadata)
415+
416+
373417
#--------------------
374418
#--------------------
375419
#### Pretty printing
@@ -549,6 +593,11 @@ showraw(t) = showraw(stdout, t)
549593
sdict(kv...) = Dict{Any, Number}(kv...)
550594

551595
const SN = Symbolic{<:Number}
596+
# TODO Reviewme this is necessary for Metatheory.jl egraph rewriting
597+
# integration. Constructors of `Add, Mul, Pow...` from Base (+, *, ^, ...)
598+
# Should now accepts EClasses as arguments.
599+
const SN_EC = Union{SN, EClass}
600+
552601
"""
553602
Add(T, coeff, dict::Dict)
554603
@@ -583,7 +632,6 @@ end
583632

584633
TermInterface.symtype(a::Add{X}) where {X} = X
585634

586-
587635
TermInterface.istree(a::Type{Add}) = true
588636

589637
TermInterface.operation(a::Add) = +
@@ -603,6 +651,17 @@ Base.isequal(a::Add, b::Add) = a.coeff == b.coeff && isequal(a.dict, b.dict)
603651

604652
Base.show(io::IO, a::Add) = show_term(io, a)
605653

654+
function toterm(t::Add{T}) where T
655+
args = []
656+
for (k, coeff) in t.dict
657+
push!(args, coeff == 1 ? k : Term{T}(*, [coeff, k]))
658+
end
659+
Term{T}(+, args)
660+
end
661+
662+
toterm(t) = t
663+
664+
606665
"""
607666
makeadd(sign, coeff::Number, xs...)
608667
@@ -641,7 +700,7 @@ add_t(a,b) = promote_symtype(+, symtype(a), symtype(b))
641700
sub_t(a,b) = promote_symtype(-, symtype(a), symtype(b))
642701
sub_t(a) = promote_symtype(-, symtype(a))
643702

644-
function +(a::SN, b::SN)
703+
function +(a::SN_EC, b::SN_EC)
645704
if a isa Add
646705
coeff, dict = makeadd(1, 0, b)
647706
T = promote_symtype(+, symtype(a), symtype(b))
@@ -652,11 +711,11 @@ function +(a::SN, b::SN)
652711
Add(add_t(a,b), makeadd(1, 0, a, b)...)
653712
end
654713

655-
+(a::Number, b::SN) = Add(add_t(a,b), makeadd(1, a, b)...)
714+
+(a::Number, b::SN_EC) = Add(add_t(a,b), makeadd(1, a, b)...)
656715

657-
+(a::SN, b::Number) = Add(add_t(a,b), makeadd(1, b, a)...)
716+
+(a::SN_EC, b::Number) = Add(add_t(a,b), makeadd(1, b, a)...)
658717

659-
+(a::SN) = a
718+
+(a::SN_EC) = a
660719

661720
+(a::Add, b::Add) = Add(add_t(a,b),
662721
a.coeff + b.coeff,
@@ -668,17 +727,17 @@ end
668727

669728
-(a::Add) = Add(sub_t(a), -a.coeff, mapvalues((_,v) -> -v, a.dict))
670729

671-
-(a::SN) = Add(sub_t(a), makeadd(-1, 0, a)...)
730+
-(a::SN_EC) = Add(sub_t(a), makeadd(-1, 0, a)...)
672731

673732
-(a::Add, b::Add) = Add(sub_t(a,b),
674733
a.coeff - b.coeff,
675734
_merge(-, a.dict, b.dict, filter=_iszero))
676735

677-
-(a::SN, b::SN) = a + (-b)
736+
-(a::SN_EC, b::SN_EC) = a + (-b)
678737

679-
-(a::Number, b::SN) = a + (-b)
738+
-(a::Number, b::SN_EC) = a + (-b)
680739

681-
-(a::SN, b::Number) = a + (-b)
740+
-(a::SN_EC, b::Number) = a + (-b)
682741

683742
"""
684743
Mul(T, coeff, dict)
@@ -753,6 +812,16 @@ Base.isequal(a::Mul, b::Mul) = a.coeff == b.coeff && isequal(a.dict, b.dict)
753812

754813
Base.show(io::IO, a::Mul) = show_term(io, a)
755814

815+
function toterm(t::Mul{T}) where T
816+
args = []
817+
push!(args, t.coeff)
818+
for (k, deg) in t.dict
819+
push!(args, deg == 1 ? k : Term{T}(^, [k, deg]))
820+
end
821+
Term{T}(*, args)
822+
end
823+
824+
756825
function makemul(coeff, xs...; d=sdict())
757826
for x in xs
758827
if x isa Pow && x.exp isa Number
@@ -777,9 +846,9 @@ end
777846
mul_t(a,b) = promote_symtype(*, symtype(a), symtype(b))
778847
mul_t(a) = promote_symtype(*, symtype(a))
779848

780-
*(a::SN) = a
849+
*(a::SN_EC) = a
781850

782-
function *(a::SN, b::SN)
851+
function *(a::SN_EC, b::SN_EC)
783852
# Always make sure Div wraps Mul
784853
if a isa Div && b isa Div
785854
Div(a.num * b.num, a.den * b.den)
@@ -796,7 +865,7 @@ end
796865
a.coeff * b.coeff,
797866
_merge(+, a.dict, b.dict, filter=_iszero))
798867

799-
function *(a::Number, b::SN)
868+
function *(a::Number, b::SN_EC)
800869
if iszero(a)
801870
a
802871
elseif isone(a)
@@ -812,17 +881,17 @@ function *(a::Number, b::SN)
812881
end
813882
end
814883

815-
*(a::SN, b::Number) = b * a
884+
*(a::SN_EC, b::Number) = b * a
816885

817-
\(a::SN, b::Union{Number, SN}) = b / a
886+
\(a::SN_EC, b::Union{Number, SN_EC}) = b / a
818887

819-
\(a::Number, b::SN) = b / a
888+
\(a::Number, b::SN_EC) = b / a
820889

821-
/(a::SN, b::Number) = (b isa Integer ? 1//b : inv(b)) * a
890+
/(a::SN_EC, b::Number) = (b isa Integer ? 1//b : inv(b)) * a
822891

823-
//(a::Union{SN, Number}, b::SN) = a / b
892+
//(a::Union{SN_EC, Number}, b::SN_EC) = a / b
824893

825-
//(a::SN, b::T) where {T <: Number} = (one(T) // b) * a
894+
//(a::SN_EC, b::T) where {T <: Number} = (one(T) // b) * a
826895

827896
"""
828897
Div(numerator_factors, denominator_factors, simplified=false)
@@ -901,7 +970,11 @@ end
901970

902971
Base.show(io::IO, d::Div) = show_term(io, d)
903972

904-
/(a::Union{SN,Number}, b::SN) = Div(a,b)
973+
function toterm(t::Div{T}) where T
974+
Term{T}(/, [t.num, t.den])
975+
end
976+
977+
/(a::Union{SN_EC,Number}, b::SN_EC) = Div(a,b)
905978

906979
"""
907980
Pow(base, exp)
@@ -944,6 +1017,10 @@ Base.isequal(p::Pow, b::Pow) = isequal(p.base, b.base) && isequal(p.exp, b.exp)
9441017

9451018
Base.show(io::IO, p::Pow) = show_term(io, p)
9461019

1020+
function toterm(t::Pow{T}) where T
1021+
Term{T}(^, [t.base, t.exp])
1022+
end
1023+
9471024
function makepow(a, b)
9481025
base = a
9491026
exp = b
@@ -954,11 +1031,11 @@ function makepow(a, b)
9541031
return (base, exp)
9551032
end
9561033

957-
^(a::SN, b) = Pow(a, b)
1034+
^(a::SN_EC, b) = Pow(a, b)
9581035

959-
^(a::SN, b::SN) = Pow(a, b)
1036+
^(a::SN_EC, b::SN_EC) = Pow(a, b)
9601037

961-
^(a::Number, b::SN) = Pow(a, b)
1038+
^(a::Number, b::SN_EC) = Pow(a, b)
9621039

9631040
function ^(a::Mul, b::Number)
9641041
coeff = unstable_pow(a.coeff, b)

0 commit comments

Comments
 (0)