Skip to content

Commit ee3a702

Browse files
authored
Merge pull request #166 from JuliaSymbolics/s/fast-sub
WIP: Attempt to speed up substitute
2 parents 7d60e4a + dcfca99 commit ee3a702

File tree

3 files changed

+51
-9
lines changed

3 files changed

+51
-9
lines changed

benchmark/benchmarks.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,4 +51,19 @@ let r = @rule(~x => ~x), rs = RuleSet([r]),
5151
overhead["simplify"]["randterm (/, *):serial"] = @benchmarkable simplify($ex2, threaded=false)
5252
overhead["simplify"]["randterm (+, *):thread"] = @benchmarkable simplify($ex1, threaded=true)
5353
overhead["simplify"]["randterm (/, *):thread"] = @benchmarkable simplify($ex2, threaded=true)
54+
55+
overhead["substitute"] = BenchmarkGroup()
56+
57+
58+
overhead["substitute"]["a"] = @benchmarkable substitute(subs_expr, $(Dict(a=>1))) setup=begin
59+
subs_expr = (sin(a+b) + cos(b+c)) * (sin(b+c) + cos(c+a)) * (sin(c+a) + cos(a+b))
60+
end
61+
62+
overhead["substitute"]["a,b"] = @benchmarkable substitute(subs_expr, $(Dict(a=>1, b=>2))) setup=begin
63+
subs_expr = (sin(a+b) + cos(b+c)) * (sin(b+c) + cos(c+a)) * (sin(c+a) + cos(a+b))
64+
end
65+
66+
overhead["substitute"]["a,b,c"] = @benchmarkable substitute(subs_expr, $(Dict(a=>1, b=>2, c=>3))) setup=begin
67+
subs_expr = (sin(a+b) + cos(b+c)) * (sin(b+c) + cos(c+a)) * (sin(c+a) + cos(a+b))
68+
end
5469
end

src/api.jl

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,23 @@ substitute any subexpression that matches a key in `dict` with
3838
the corresponding value.
3939
"""
4040
function substitute(expr, dict; fold=true)
41-
rs = Prewalk(PassThrough(@rule ~x::(x->haskey(dict, x)) => dict[~x]))
42-
if fold
43-
rs(to_symbolic(expr)) |> SymbolicUtils.fold
41+
haskey(dict, expr) && return dict[expr]
42+
43+
if istree(expr)
44+
if fold
45+
canfold=true
46+
args = map(arguments(expr)) do x
47+
x′ = substitute(x, dict; fold=fold)
48+
canfold = canfold && !(x′ isa Symbolic)
49+
x′
50+
end
51+
canfold && return operation(expr)(args...)
52+
args
53+
else
54+
args = map(x->substitute(x, dict), arguments(expr))
55+
end
56+
similarterm(expr, operation(expr), args)
4457
else
45-
rs(to_symbolic(expr))
58+
expr
4659
end
4760
end

src/types.jl

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,8 @@ See [promote_symtype](#promote_symtype)
280280
struct Term{T} <: Symbolic{T}
281281
f::Any
282282
arguments::Any
283+
hash::Ref{UInt} # hash cache
284+
Term{T}(f, xs) where {T} = new{T}(f, xs, Ref{UInt}(0))
283285
end
284286

285287
istree(t::Term) = true
@@ -294,7 +296,9 @@ arguments(x::Term) = getfield(x, :arguments)
294296
hashvec(xs, z) = foldr(hash, xs, init=z)
295297

296298
function Base.hash(t::Term{T}, salt::UInt) where {T}
297-
hashvec(arguments(t), hash(operation(t), hash(T, salt)))
299+
h = t.hash[]
300+
!iszero(h) && return h
301+
t.hash[] = hashvec(arguments(t), hash(operation(t), hash(T, salt)))
298302
end
299303

300304
function term(f, args...; type = nothing)
@@ -408,6 +412,7 @@ struct Add{X, T<:Number, D} <: Symbolic{X}
408412
coeff::T
409413
dict::D
410414
sorted_args_cache::Ref{Any}
415+
hash::Ref{UInt}
411416
end
412417

413418
function Add(T, coeff, dict)
@@ -418,7 +423,7 @@ function Add(T, coeff, dict)
418423
return _isone(v) ? k : Mul(T, makemul(v, k)...)
419424
end
420425

421-
Add{T, typeof(coeff), typeof(dict)}(coeff, dict, Ref{Any}(nothing))
426+
Add{T, typeof(coeff), typeof(dict)}(coeff, dict, Ref{Any}(nothing), Ref{UInt}(0))
422427
end
423428

424429
symtype(a::Add{X}) where {X} = X
@@ -434,7 +439,11 @@ function arguments(a::Add)
434439
a.sorted_args_cache[] = iszero(a.coeff) ? args : vcat(a.coeff, args)
435440
end
436441

437-
Base.hash(a::Add, u::UInt64) = hash(a.coeff, hash(a.dict, u))
442+
function Base.hash(a::Add, u::UInt64)
443+
h = a.hash[]
444+
!iszero(h) && return h
445+
a.hash[] = hash(0xaddaddaddaddadda, hash(a.coeff, hash(a.dict, u)))
446+
end
438447

439448
Base.isequal(a::Add, b::Add) = isequal(a.coeff, b.coeff) && isequal(a.dict, b.dict)
440449

@@ -529,6 +538,7 @@ struct Mul{X, T<:Number, D} <: Symbolic{X}
529538
coeff::T
530539
dict::D
531540
sorted_args_cache::Ref{Any}
541+
hash::Ref{UInt}
532542
end
533543

534544
function Mul(T, a,b)
@@ -541,7 +551,7 @@ function Mul(T, a,b)
541551
return Pow(first(pair), last(pair))
542552
end
543553
else
544-
Mul{T, typeof(a), typeof(b)}(a,b, Ref{Any}(nothing))
554+
Mul{T, typeof(a), typeof(b)}(a,b, Ref{Any}(nothing), Ref{UInt}(0))
545555
end
546556
end
547557

@@ -557,7 +567,11 @@ function arguments(a::Mul)
557567
a.sorted_args_cache[] = isone(a.coeff) ? args : vcat(a.coeff, args)
558568
end
559569

560-
Base.hash(m::Mul, u::UInt64) = hash(m.coeff, hash(m.dict, u))
570+
function Base.hash(m::Mul, u::UInt64)
571+
h = m.hash[]
572+
!iszero(h) && return h
573+
m.hash[] = hash(0xaaaaaaaaaaaaaaa, hash(m.coeff, hash(m.dict, u)))
574+
end
561575

562576
Base.isequal(a::Mul, b::Mul) = isequal(a.coeff, b.coeff) && isequal(a.dict, b.dict)
563577

0 commit comments

Comments
 (0)