Skip to content

Commit a69f8e8

Browse files
Merge pull request #615 from JuliaSymbolics/b/613-default-arguments-to-unsorted_arguments-to-accelerate-term-traversal
Optimize `arguments` function by removing sorting
2 parents 1cdd436 + 274e3cf commit a69f8e8

File tree

12 files changed

+56
-38
lines changed

12 files changed

+56
-38
lines changed

src/SymbolicUtils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ import TermInterface: iscall, isexpr, issym, symtype, head, children,
2121

2222
const istree = iscall
2323
Base.@deprecate_binding istree iscall
24-
export istree, operation, arguments, unsorted_arguments, similarterm, iscall
24+
export istree, operation, arguments, sorted_arguments, similarterm, iscall
2525
# Sym, Term,
2626
# Add, Mul and Pow
2727
include("types.jl")

src/code.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ export toexpr, Assignment, (←), Let, Func, DestructuredArgs, LiteralExpr,
99
import ..SymbolicUtils
1010
import ..SymbolicUtils.Rewriters
1111
import SymbolicUtils: @matchable, BasicSymbolic, Sym, Term, iscall, operation, arguments, issym,
12-
symtype, similarterm, unsorted_arguments, metadata, isterm, term, maketerm
12+
symtype, similarterm, sorted_arguments, metadata, isterm, term, maketerm
1313

1414
##== state management ==##
1515

@@ -124,7 +124,7 @@ end
124124
function function_to_expr(op::Union{typeof(*),typeof(+)}, O, st)
125125
out = get(st.rewrites, O, nothing)
126126
out === nothing || return out
127-
args = map(Base.Fix2(toexpr, st), arguments(O))
127+
args = map(Base.Fix2(toexpr, st), sorted_arguments(O))
128128
if length(args) >= 3 && symtype(O) <: Number
129129
x, xs = Iterators.peel(args)
130130
foldl(xs, init=x) do a, b
@@ -744,7 +744,7 @@ end
744744
function cse_state!(state, t)
745745
!iscall(t) && return t
746746
state[t] = Base.get(state, t, 0) + 1
747-
foreach(x->cse_state!(state, x), unsorted_arguments(t))
747+
foreach(x->cse_state!(state, x), arguments(t))
748748
end
749749

750750
function cse_block!(assignments, counter, names, name, state, x)
@@ -759,7 +759,7 @@ function cse_block!(assignments, counter, names, name, state, x)
759759
return sym
760760
end
761761
elseif iscall(x)
762-
args = map(a->cse_block!(assignments, counter, names, name, state,a), unsorted_arguments(x))
762+
args = map(a->cse_block!(assignments, counter, names, name, state,a), arguments(x))
763763
if isterm(x)
764764
return term(operation(x), args...)
765765
else

src/inspect.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,16 @@ function AbstractTrees.nodevalue(x::BasicSymbolic)
2626
Text(str)
2727
end
2828

29+
"""
30+
$(TYPEDSIGNATURES)
31+
32+
Return the children of the symbolic expression `x`, sorted by their order in
33+
the expression.
34+
35+
This function is used internally for printing via AbstractTrees.
36+
"""
2937
function AbstractTrees.children(x::Symbolic)
30-
iscall(x) ? arguments(x) : isexpr(x) ? children(x) : ()
38+
iscall(x) ? sorted_arguments(x) : isexpr(x) ? sorted_children(x) : ()
3139
end
3240

3341
"""

src/interface.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,22 +36,22 @@ is the function being called.
3636
function operation end
3737

3838
"""
39-
arguments(x)
39+
sorted_arguments(x)
4040
4141
Get the arguments of `x`, must be defined if `iscall(x)` is `true`.
4242
"""
43-
function arguments end
43+
function sorted_arguments end
4444

4545
"""
46-
unsorted_arguments(x::T)
46+
sorted_arguments(x::T)
4747
4848
If x is a term satisfying `iscall(x)` and your term type `T` provides
4949
an optimized implementation for storing the arguments, this function can
5050
be used to retrieve the arguments when the order of arguments does not matter
5151
but the speed of the operation does.
5252
"""
53-
unsorted_arguments(x) = arguments(x)
54-
arity(x) = length(unsorted_arguments(x))
53+
function arguments end
54+
arity(x) = length(arguments(x))
5555

5656
"""
5757
metadata(x)

src/ordering.jl

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,28 +14,31 @@
1414
<(a::T, b::S) where{T,S} = T<S
1515
<(a::T, b::T) where{T} = a < b
1616

17+
"""
18+
$(SIGNATURES)
1719
18-
###### A variation on degree lexicographic order ########
19-
# find symbols and their corresponding degrees
20+
Internal function used for printing symbolic expressions. This function determines
21+
the degrees of symbols within a given expression, implementing a variation on
22+
degree lexicographic order.
23+
"""
2024
function get_degrees(expr)
2125
if issym(expr)
2226
((Symbol(expr),) => 1,)
2327
elseif iscall(expr)
2428
op = operation(expr)
25-
args = arguments(expr)
26-
if operation(expr) == (^) && args[2] isa Number
29+
args = sorted_arguments(expr)
30+
if op == (^) && args[2] isa Number
2731
return map(get_degrees(args[1])) do (base, pow)
2832
(base => pow * args[2])
2933
end
30-
elseif operation(expr) == (*)
34+
elseif op == (*)
3135
return mapreduce(get_degrees,
3236
(x,y)->(x...,y...,), args)
33-
elseif operation(expr) == (+)
37+
elseif op == (+)
3438
ds = map(get_degrees, args)
3539
_, idx = findmax(x->sum(last.(x), init=0), ds)
3640
return ds[idx]
37-
elseif operation(expr) == (getindex)
38-
args = arguments(expr)
41+
elseif op == (getindex)
3942
return ((Symbol.(args)...,) => 1,)
4043
else
4144
return ((Symbol("zzzzzzz", hash(expr)),) => 1,)
@@ -62,7 +65,7 @@ function lexlt(degs1, degs2)
6265
return false # they are equal
6366
end
6467

65-
_arglen(a) = iscall(a) ? length(unsorted_arguments(a)) : 0
68+
_arglen(a) = iscall(a) ? length(arguments(a)) : 0
6669

6770
function <(a::Tuple, b::Tuple)
6871
for (x, y) in zip(a, b)

src/polyform.jl

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,9 @@ function arguments(x::PolyForm{T}) where {T}
231231
PolyForm{T}(t, x.pvar2sym, x.sym2term, nothing)) for t in ts]
232232
end
233233
end
234+
235+
sorted_arguments(x::PolyForm) = arguments(x)
236+
234237
children(x::PolyForm) = [operation(x); arguments(x)]
235238

236239
Base.show(io::IO, x::PolyForm) = show_term(io, x)
@@ -344,7 +347,7 @@ end
344347

345348
function add_with_div(x, flatten=true)
346349
(!iscall(x) || operation(x) != (+)) && return x
347-
aa = unsorted_arguments(x)
350+
aa = arguments(x)
348351
!any(a->isdiv(a), aa) && return x # no rewrite necessary
349352

350353
divs = filter(a->isdiv(a), aa)
@@ -382,12 +385,12 @@ end
382385

383386
function needs_div_rules(x)
384387
(isdiv(x) && !(x.num isa Number) && !(x.den isa Number)) ||
385-
(iscall(x) && operation(x) === (+) && count(has_div, unsorted_arguments(x)) > 1) ||
386-
(iscall(x) && any(needs_div_rules, unsorted_arguments(x)))
388+
(iscall(x) && operation(x) === (+) && count(has_div, arguments(x)) > 1) ||
389+
(iscall(x) && any(needs_div_rules, arguments(x)))
387390
end
388391

389392
function has_div(x)
390-
return isdiv(x) || (iscall(x) && any(has_div, unsorted_arguments(x)))
393+
return isdiv(x) || (iscall(x) && any(has_div, arguments(x)))
391394
end
392395

393396
flatten_pows(xs) = map(xs) do x
@@ -415,8 +418,8 @@ Has optimized processes for `Mul` and `Pow` terms.
415418
function quick_cancel(d)
416419
if ispow(d) && isdiv(d.base)
417420
return quick_cancel((d.base.num^d.exp) / (d.base.den^d.exp))
418-
elseif ismul(d) && any(isdiv, unsorted_arguments(d))
419-
return prod(unsorted_arguments(d))
421+
elseif ismul(d) && any(isdiv, arguments(d))
422+
return prod(arguments(d))
420423
elseif isdiv(d)
421424
num, den = quick_cancel(d.num, d.den)
422425
return Div(num, den)

src/rewriters.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ module Rewriters
3333
using SymbolicUtils: @timer
3434
using TermInterface
3535

36-
import SymbolicUtils: iscall, operation, arguments, unsorted_arguments, metadata, node_count, _promote_symtype
36+
import SymbolicUtils: iscall, operation, arguments, sorted_arguments, metadata, node_count, _promote_symtype
3737
export Empty, IfElse, If, Chain, RestartedChain, Fixpoint, Postwalk, Prewalk, PassThrough
3838

3939
# Cache of printed rules to speed up @timer
@@ -221,7 +221,7 @@ function (p::Walk{ord, C, F, false})(x) where {ord, C, F}
221221

222222
if iscall(x)
223223
x = p.maketerm(x, operation(x), map(PassThrough(p),
224-
unsorted_arguments(x)), metadata=metadata(x))
224+
arguments(x)), metadata=metadata(x))
225225
end
226226

227227
return ord === :post ? p.rw(x) : x

src/rule.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -399,7 +399,7 @@ function (acr::ACRule)(term)
399399
end
400400

401401
T = symtype(term)
402-
args = unsorted_arguments(term)
402+
args = arguments(term)
403403

404404
itr = acr.sets(eachindex(args), acr.arity)
405405

src/simplify.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,6 @@ end
4545

4646
has_operation(x, op) = (iscall(x) && (operation(x) == op ||
4747
any(a->has_operation(a, op),
48-
unsorted_arguments(x))))
48+
arguments(x))))
4949

5050
Base.@deprecate simplify(x, ctx; kwargs...) simplify(x; rewriter=ctx, kwargs...)

src/substitute.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,15 @@ function substitute(expr, dict; fold=true)
2020
op = substitute(operation(expr), dict; fold=fold)
2121
if fold
2222
canfold = !(op isa Symbolic)
23-
args = map(unsorted_arguments(expr)) do x
23+
args = map(arguments(expr)) do x
2424
x′ = substitute(x, dict; fold=fold)
2525
canfold = canfold && !(x′ isa Symbolic)
2626
x′
2727
end
2828
canfold && return op(args...)
2929
args
3030
else
31-
args = map(x->substitute(x, dict, fold=fold), unsorted_arguments(expr))
31+
args = map(x->substitute(x, dict, fold=fold), arguments(expr))
3232
end
3333

3434
maketerm(typeof(expr),

0 commit comments

Comments
 (0)