Skip to content

Commit 2e743f1

Browse files
authored
Multithreaded rules application (#66)
* initial attempt at multi-threaded rules * don't multithread by default * change cutoff * account for NaN in fuzz tests * increment version because this would drop support for pre 1.3 * add some tests * introduce function barrier in hopes that optimizer wont give up * fix * updated rule to use new interface * typo * change to thread_subtree_cutoff * add threaded benchmarks; fixup tests * additional function barrier * don't splat * typos * fix * reduce splatting * shuffle around logic * disable applyall on recurse * more disabling of applyall in recurse
1 parent 864e3e5 commit 2e743f1

File tree

9 files changed

+62
-16
lines changed

9 files changed

+62
-16
lines changed

.travis.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ branches:
88
os:
99
- linux
1010
julia:
11-
- 1
11+
- 1.3.1
1212
- 1.4.1
1313
- nightly
1414
matrix:

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ Combinatorics = "1.0"
1414
NaNMath = "0.3"
1515
SpecialFunctions = "0.10"
1616
TimerOutputs = "0.5"
17-
julia = "1"
17+
julia = "1.3"
1818

1919
[extras]
2020
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"

benchmark/benchmarks.jl

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@ using Random
55

66
SUITE = BenchmarkGroup()
77

8-
@syms a b c
8+
@syms a b c d; Random.seed!(123);
9+
910
let r = @rule(~x => ~x), rs = RuleSet([r]),
1011
acr = @rule(~x::isnumber + ~y => ~y)
1112

@@ -37,4 +38,19 @@ let r = @rule(~x => ~x), rs = RuleSet([r]),
3738
overhead["simplify_no_fixp"]["noop:Int"] = @benchmarkable simplify(1, fixpoint=false)
3839
overhead["simplify_no_fixp"]["noop:Sym"] = @benchmarkable simplify($a, fixpoint=false)
3940
overhead["simplify_no_fixp"]["noop:Term"] = @benchmarkable simplify($(a+2), fixpoint=false)
41+
42+
function random_term(len; atoms, funs, fallback_atom=1)
43+
xs = rand(atoms, len)
44+
while length(xs) > 1
45+
xs = map(Iterators.partition(xs, 2)) do xy
46+
x = xy[1]; y = get(xy, 2, fallback_atom)
47+
rand(funs)(x, y)
48+
end
49+
end
50+
xs[]
51+
end
52+
ex = random_term(1000, atoms=[a, b, c, d, a^(-1), b^(-1), 1, 2.0], funs=[+, *])
53+
54+
overhead["simplify_no_fixp"]["randterm:serial"] = @benchmarkable simplify($ex, threaded=false, fixpoint=false)
55+
overhead["simplify_no_fixp"]["randterm:thread"] = @benchmarkable simplify($ex, threaded=true, fixpoint=false)
4056
end

src/SymbolicUtils.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ else
2727
end
2828
end
2929

30-
3130
export @syms, term, @fun, showraw
3231
include("types.jl")
3332

src/rule_dsl.jl

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -248,19 +248,41 @@ struct RuleRewriteError
248248
expr
249249
end
250250

251-
function (r::RuleSet)(term, context=EmptyCtx(); depth=typemax(Int), applyall=false, recurse=true)
251+
node_count(atom, count; cutoff) = count + 1
252+
node_count(t::Term, count=0; cutoff=100) = sum(node_count(arg, count; cutoff=cutoff) for arg arguments(t))
253+
254+
function _recurse_apply_ruleset_threaded(r::RuleSet, term, context; depth, thread_subtree_cutoff)
255+
_args = map(arguments(term)) do arg
256+
if node_count(arg) > thread_subtree_cutoff
257+
Threads.@spawn r(arg, context; depth=depth-1, threaded=true,
258+
thread_subtree_cutoff=thread_subtree_cutoff)
259+
else
260+
r(arg, context; depth=depth-1, threaded=false)
261+
end
262+
end
263+
args = map(t -> t isa Task ? fetch(t) : t, _args)
264+
Term{symtype(term)}(operation(term), args)
265+
end
266+
267+
function (r::RuleSet)(term, context=EmptyCtx(); depth=typemax(Int), applyall::Bool=false, recurse::Bool=true,
268+
threaded::Bool=false, thread_subtree_cutoff::Int=100)
252269
rules = r.rules
253270
term = to_symbolic(term)
254271
# simplify the subexpressions
255272
if depth == 0
256273
return term
257274
end
258275
if term isa Symbolic
259-
if term isa Term && recurse
260-
expr = Term{symtype(term)}(operation(term),
261-
map(t -> r(t, context, depth=depth-1), arguments(term)))
276+
expr = if term isa Term && recurse
277+
if threaded
278+
_recurse_apply_ruleset_threaded(r, term, context; depth=depth,
279+
thread_subtree_cutoff=thread_subtree_cutoff)
280+
else
281+
expr = Term{symtype(term)}(operation(term),
282+
map(t -> r(t, context, depth=depth-1), arguments(term)))
283+
end
262284
else
263-
expr = term
285+
term
264286
end
265287
for i in 1:length(rules)
266288
expr′ = try
@@ -282,6 +304,7 @@ function (r::RuleSet)(term, context=EmptyCtx(); depth=typemax(Int), applyall=fal
282304
return expr # no rule applied
283305
end
284306

307+
285308
getdepth(::RuleSet) = typemax(Int)
286309

287310
function fixpoint(f, x, ctx; kwargs...)

src/simplify.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,11 @@ Applies them once if `fixpoint=false`.
3333
The `applyall` and `recurse` keywords are forwarded to the enclosed
3434
`RuleSet`, they are mainly used for internal optimization.
3535
"""
36-
function simplify(x, ctx=EmptyCtx(); rules=default_rules(x, ctx), fixpoint=true, applyall=true, recurse=true, kwargs...)
36+
function simplify(x, ctx=EmptyCtx(); rules=default_rules(x, ctx), fixpoint=true, applyall=true, kwargs...)
3737
if fixpoint
38-
SymbolicUtils.fixpoint(rules, x, ctx; recurse=recurse, applyall=recurse)
38+
SymbolicUtils.fixpoint(rules, x, ctx; applyall=applyall)
3939
else
40-
rules(x, ctx; recurse=recurse, applyall=recurse)
40+
rules(x, ctx; applyall=applyall, kwargs...)
4141
end
4242
end
4343

test/benchmark.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# A little trick for travis
2-
using PkgBenchmark
2+
using PkgBenchmark, SymbolicUtils
33

44
pkgpath = dirname(dirname(pathof(SymbolicUtils)))
55
# move it out of the repository so that you can check out different branches

test/fuzz.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
include("fuzzlib.jl")
32

43
using Random: seed!

test/rulesets.jl

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,7 @@ using SymbolicUtils: fixpoint, getdepth
1313
ex = 2 * (w+w+α+β)
1414

1515
@eqtest rset(ex) == (((2 * w) + (2 * w)) + (2 * α)) + (2 * β)
16-
@eqtest rset(ex) == simplify(ex, rset; fixpoint=false, applyall=false)
17-
16+
@eqtest rset(ex) == simplify(ex; rules=rset, fixpoint=false, applyall=false)
1817
@eqtest fixpoint(rset, ex, "ctx") == ((2 * (2 * w)) + (2 * α)) + (2 * β)
1918
end
2019

@@ -103,6 +102,16 @@ end
103102
@test sprint(io->Base.showerror(io, err)) == "Failed to apply rule ~x + ~(y::pred) => ~x on expression a + b"
104103
end
105104

105+
@testset "Threading" begin
106+
@syms a b c d
107+
ex = (((0.6666666666666666 / (c / 1)) + ((1 * a) / (c / 1))) +
108+
(1.0 / (((1 * d) / (1 + b)) * (1 / b)))) +
109+
((((1 * a) + (1 * a)) / ((2.0 * (d + 1)) / 1.0)) +
110+
((((d * 1) / (1 + c)) * 2.0) / ((1 / d) + (1 / c))))
111+
@eqtest simplify(ex) == simplify(ex, threaded=true, thread_subtree_cutoff=3)
112+
@test SymbolicUtils.node_count(a + b * c / d) == 4
113+
end
114+
106115
@testset "timerwrite" begin
107116
@syms a b c d
108117
expr1 = foldr((x,y)->rand([*, /])(x,y), rand([a,b,c,d], 100))

0 commit comments

Comments
 (0)