Skip to content

Commit bf54f40

Browse files
authored
Merge branch 'master' into compathelper/new_version/2021-10-24-18-48-57-857-02801455533
2 parents 14e97ba + 701ce7f commit bf54f40

File tree

7 files changed

+233
-25
lines changed

7 files changed

+233
-25
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SymbolicUtils"
22
uuid = "d1185830-fcd6-423d-90d6-eec64667417b"
33
authors = ["Shashi Gowda"]
4-
version = "0.17.0"
4+
version = "0.18.0"
55

66
[deps]
77
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
@@ -36,7 +36,7 @@ DocStringExtensions = "0.8"
3636
DynamicPolynomials = "0.3"
3737
IfElse = "0.1"
3838
LabelledArrays = "1.5"
39-
Metatheory = "1.0"
39+
Metatheory = "1.2"
4040
MultivariatePolynomials = "0.3, 0.4"
4141
NaNMath = "0.3"
4242
Setfield = "0.7, 0.8"

src/SymbolicUtils.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,13 @@ include("simplify_rules.jl")
5757
export simplify, substitute
5858
include("api.jl")
5959

60+
# EGraph rewriting
61+
include("egraph.jl")
62+
export optimize
63+
6064
include("code.jl")
6165

66+
6267
# ADjoints
6368
include("adjoints.jl")
6469

src/egraph.jl

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
using Metatheory.Rewriters
2+
3+
function EGraphs.preprocess(t::Symbolic)
4+
toterm(unflatten(t))
5+
end
6+
7+
function symbolicegraph(ex)
8+
g = EGraph(ex)
9+
analyze!(g, SymbolicUtils.SymtypeAnalysis)
10+
settermtype!(g, Term{symtype(ex), Any})
11+
return g
12+
end
13+
14+
15+
"""
16+
Equational rewrite rules for optimizing expressions
17+
"""
18+
opt_theory = @theory a b x y begin
19+
a + b == b + a
20+
a * b == b * a
21+
a * x + a * y == a*(x+y)
22+
-1 * a == -a
23+
a + (-1 * b) == a - b
24+
x^-1 == 1/x
25+
1/x * a == a/x
26+
# fraction rules
27+
# (a/b) + (c/b) => (a+c)/b
28+
# trig functions
29+
sin(x)/cos(x) == tan(x)
30+
cos(x)/sin(x) == cot(x)
31+
sin(x)^2 + cos(x)^2 --> 1
32+
sin(2a) == 2sin(a)cos(a)
33+
end
34+
35+
36+
"""
37+
Approximation of costs of operators in number
38+
of CPU cycles required for the numerical computation
39+
40+
See
41+
* https://latkin.org/blog/2014/11/09/a-simple-benchmark-of-various-math-operations/
42+
* https://streamhpc.com/blog/2012-07-16/how-expensive-is-an-operation-on-a-cpu/
43+
* https://github.com/triscale-innov/GFlops.jl
44+
"""
45+
const op_costs = Dict(
46+
(+) => 1,
47+
(-) => 1,
48+
abs => 2,
49+
(*) => 3,
50+
exp => 18,
51+
(/) => 24,
52+
(^) => 100,
53+
log1p => 124,
54+
deg2rad => 125,
55+
rad2deg => 125,
56+
acos => 127,
57+
asind => 128,
58+
acsch => 133,
59+
sin => 134,
60+
cos => 134,
61+
atan => 135,
62+
tan => 156,
63+
)
64+
# TODO some operator costs are in FLOP and not in cycles!!
65+
66+
function costfun(n::ENodeTerm, g::EGraph, an)
67+
op = operation(n)
68+
cost = 0
69+
cost += get(op_costs, op, 1)
70+
71+
for id n.args
72+
eclass = g[id]
73+
!hasdata(eclass, an) && (cost += Inf; break)
74+
cost += last(getdata(eclass, an))
75+
end
76+
cost
77+
end
78+
79+
costfun(n::ENodeLiteral, g::EGraph, an) = 0
80+
81+
egraph_simterm(x, head, args, symtype=nothing; metadata=nothing, exprhead=exprhead(x)) =
82+
TermInterface.similarterm(typeof(x), head, args, symtype; metadata=metadata, exprhead=exprhead)
83+
84+
85+
# Custom similarterm to use in EGraphs on <:Symbolic types that treats everything as a Term
86+
function egraph_simterm(x::Type{<:Term}, f, args, symtype=nothing; metadata=nothing, exprhead=:call)
87+
T = symtype
88+
if T === nothing
89+
T = _promote_symtype(f, args)
90+
end
91+
res = Term{T}(f isa Symbol ? eval(f) : f, args; metadata=metadata);
92+
return res
93+
end
94+
95+
function optimize(ex; params=SaturationParams(timeout=20))
96+
# @show ex
97+
g = symbolicegraph(ex)
98+
params.simterm = egraph_simterm
99+
saturate!(g, opt_theory, params)
100+
return extract!(g, costfun)
101+
end

src/polyform.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,10 @@ function unpolyize(x)
236236
Postwalk(identity, similarterm=simterm)(x)
237237
end
238238

239+
function toterm(x::PolyForm)
240+
toterm(unpolyize(x))
241+
end
242+
239243
## Rational Polynomial form with Div
240244

241245
function polyform_factors(d, pvar2sym, sym2term)

0 commit comments

Comments
 (0)