Skip to content

Commit 8725370

Browse files
authored
Add sort kwarg in arguments to avoid extra computation (#337)
1 parent dcd3b3a commit 8725370

File tree

4 files changed

+21
-9
lines changed

4 files changed

+21
-9
lines changed

src/api.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -57,21 +57,22 @@ function substitute(expr, dict; fold=true)
5757
haskey(dict, expr) && return dict[expr]
5858

5959
if istree(expr)
60+
op = substitute(operation(expr), dict; fold=fold)
6061
if fold
61-
canfold = !(operation(expr) isa Symbolic)
62-
args = map(arguments(expr)) do x
62+
canfold = !(op isa Symbolic)
63+
args = map(unsorted_arguments(expr)) do x
6364
x′ = substitute(x, dict; fold=fold)
6465
canfold = canfold && !(x′ isa Symbolic)
6566
x′
6667
end
67-
canfold && return operation(expr)(args...)
68+
canfold && return op(args...)
6869
args
6970
else
70-
args = map(x->substitute(x, dict, fold=fold), arguments(expr))
71+
args = map(x->substitute(x, dict, fold=fold), unsorted_arguments(expr))
7172
end
7273

7374
similarterm(expr,
74-
substitute(operation(expr), dict, fold=fold),
75+
op,
7576
args,
7677
symtype(expr),
7778
metadata=metadata(expr))

src/code.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ function get_assignments(d::DestructuredArgs, st)
168168
ex = (i isa Symbol ? :($name.$i) : :($name[$i]))
169169
ex = d.inbounds ? :(@inbounds($ex)) : ex
170170
a ex
171-
end
171+
end
172172
end
173173

174174
@matchable struct Let

src/types.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,7 @@ end
353353

354354
operation(x::Term) = getfield(x, :f)
355355

356+
unsorted_arguments(x) = arguments(x)
356357
arguments(x::Term) = getfield(x, :arguments)
357358

358359
function Base.isequal(t1::Term, t2::Term)
@@ -632,6 +633,11 @@ istree(a::Add) = true
632633

633634
operation(a::Add) = +
634635

636+
function unsorted_arguments(a::Add)
637+
args = [v*k for (k,v) in a.dict]
638+
iszero(a.coeff) ? args : vcat(a.coeff, args)
639+
end
640+
635641
function arguments(a::Add)
636642
a.sorted_args_cache[] !== nothing && return a.sorted_args_cache[]
637643
args = sort!([v*k for (k,v) in a.dict], lt=<ₑ)
@@ -777,6 +783,11 @@ operation(a::Mul) = *
777783

778784
unstable_pow(a, b) = a isa Integer && b isa Integer ? (a//1) ^ b : a ^ b
779785

786+
function unsorted_arguments(a::Mul)
787+
args = [unstable_pow(k, v) for (k,v) in a.dict]
788+
isone(a.coeff) ? args : vcat(a.coeff, args)
789+
end
790+
780791
function arguments(a::Mul)
781792
a.sorted_args_cache[] !== nothing && return a.sorted_args_cache[]
782793
args = sort!([unstable_pow(k, v) for (k,v) in a.dict], lt=<ₑ)

test/code.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -150,13 +150,13 @@ test_repr(a, b) = @test repr(Base.remove_linenums!(a)) == repr(Base.remove_linen
150150
:((a,b,$(+)(a,b))))
151151

152152
@test SpawnFetch{Multithreaded}([()->1,()->2],vcat)|>toexpr|>eval == [1,2]
153-
@test @elapsed(SpawnFetch{Multithreaded}([:(()->sleep(.6)),
153+
@test @elapsed(SpawnFetch{Multithreaded}([:(()->sleep(2)),
154154
Func([:x],
155155
[],
156156
:(sleep(x)))],
157157
[(),
158-
(0.6,)],
159-
vcat)|>toexpr|>eval) < 1.1
158+
(2,)],
159+
vcat)|>toexpr|>eval) < 3
160160

161161
let
162162
@syms a b

0 commit comments

Comments
 (0)