Skip to content

Commit 56743fa

Browse files
authored
Merge pull request #723 from AayushSabharwal/as/small-array
feat: optimize small arrays in `BasicSymbolic`
2 parents ab977fc + 65cc52d commit 56743fa

File tree

4 files changed

+157
-29
lines changed

4 files changed

+157
-29
lines changed

docs/src/manual/rewrite.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ If you want to match a variable number of subexpressions at once, you will need
7171
@rule(+(~~xs) => ~~xs)(x + y + z)
7272
7373
# output
74-
3-element view(::Vector{Any}, 1:3) with eltype Any:
74+
3-element view(::SymbolicUtils.SmallVec{Any, Vector{Any}}, 1:3) with eltype Any:
7575
z
7676
y
7777
x

src/SymbolicUtils.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@ import TaskLocalValues: TaskLocalValue
2525

2626
include("cache.jl")
2727
Base.@deprecate istree iscall
28+
29+
include("small_array.jl")
30+
2831
export istree, operation, arguments, sorted_arguments, iscall
2932
# Sym, Term,
3033
# Add, Mul and Pow

src/small_array.jl

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
"""
2+
$(TYPEDEF)
3+
4+
A mutable resizeable 3-length vector to implement non-allocating small vectors.
5+
6+
$(TYPEDFIELDS)
7+
"""
8+
mutable struct Backing{T} <: AbstractVector{T}
9+
"""
10+
Length of the buffer.
11+
"""
12+
len::Int
13+
x1::T
14+
x2::T
15+
x3::T
16+
17+
Backing{T}() where {T} = new{T}(0)
18+
Backing{T}(x1) where {T} = new{T}(1, x1)
19+
Backing{T}(x1, x2) where {T} = new{T}(2, x1, x2)
20+
Backing{T}(x1, x2, x3) where {T} = new{T}(3, x1, x2, x3)
21+
end
22+
23+
Base.size(x::Backing) = (x.len,)
24+
Base.isempty(x::Backing) = x.len == 0
25+
26+
"""
27+
$(TYPEDSIGNATURES)
28+
29+
Value to use when removing an element from a `Backing`. Used so stored entries can be
30+
GC'ed when removed.
31+
"""
32+
defaultval(::Type{T}) where {T <: Number} = zero(T)
33+
defaultval(::Type{Any}) = nothing
34+
35+
function Base.getindex(x::Backing, i::Int)
36+
@boundscheck 1 <= i <= x.len
37+
if i == 1
38+
x.x1
39+
elseif i == 2
40+
x.x2
41+
elseif i == 3
42+
x.x3
43+
end
44+
end
45+
46+
function Base.setindex!(x::Backing, v, i::Int)
47+
@boundscheck 1 <= i <= x.len
48+
if i == 1
49+
setfield!(x, :x1, v)
50+
elseif i == 2
51+
setfield!(x, :x2, v)
52+
elseif i == 3
53+
setfield!(x, :x3, v)
54+
end
55+
end
56+
57+
function Base.push!(x::Backing, v)
58+
x.len < 3 || throw(ArgumentError("`Backing` is full"))
59+
x.len += 1
60+
x[x.len] = v
61+
end
62+
63+
function Base.pop!(x::Backing{T}) where {T}
64+
x.len > 0 || throw(ArgumentError("Array is empty"))
65+
v = x[x.len]
66+
x[x.len] = defaultval(T)
67+
x.len -= 1
68+
v
69+
end
70+
71+
"""
72+
$(TYPEDSIGNATURES)
73+
74+
Whether the `Backing` is full.
75+
"""
76+
isfull(x::Backing) = x.len == 3
77+
78+
"""
79+
$(TYPEDSIGNATURES)
80+
81+
A small-buffer-optimized `AbstractVector`. Uses a `Backing` when the number of elements
82+
is within the size of `Backing`, and allocates a `V` when the number of elements exceed
83+
this limit.
84+
"""
85+
mutable struct SmallVec{T, V <: AbstractVector{T}} <: AbstractVector{T}
86+
data::Union{Backing{T}, V}
87+
88+
function SmallVec{T}(x::AbstractVector{T}) where {T}
89+
V = typeof(x)
90+
if length(x) < 4
91+
new{T, V}(Backing{T}(x...))
92+
else
93+
new{T, V}(x)
94+
end
95+
end
96+
97+
function SmallVec{T, V}() where {T, V}
98+
new{T, V}(Backing{T}())
99+
end
100+
101+
function SmallVec{T, V}(x::Union{Tuple, AbstractVector}) where {T, V}
102+
if length(x) <= 3
103+
new{T, V}(Backing{T}(x...))
104+
else
105+
new{T, V}(V(x isa Tuple ? collect(x) : x))
106+
end
107+
end
108+
end
109+
110+
Base.convert(::Type{SmallVec{T, V}}, x::V) where {T, V} = SmallVec{T}(x)
111+
Base.convert(::Type{SmallVec{T, V}}, x) where {T, V} = SmallVec{T}(V(x))
112+
Base.convert(::Type{SmallVec{T, V}}, x::SmallVec{T, V}) where {T, V} = x
113+
114+
Base.size(x::SmallVec) = size(x.data)
115+
Base.isempty(x::SmallVec) = isempty(x.data)
116+
Base.getindex(x::SmallVec, i::Int) = x.data[i]
117+
Base.setindex!(x::SmallVec, v, i::Int) = setindex!(x.data, v, i)
118+
119+
function Base.push!(x::SmallVec{T, V}, v) where {T, V}
120+
buf = x.data
121+
buf isa Backing{T} || return push!(buf::V, v)
122+
isfull(buf) || return push!(buf::Backing{T}, v)
123+
x.data = V(buf)
124+
return push!(x.data::V, v)
125+
end
126+
127+
Base.pop!(x::SmallVec) = pop!(x.data)
128+
129+
function Base.sizehint!(x::SmallVec{T, V}, n; kwargs...) where {T, V}
130+
x.data isa Backing && return x
131+
sizehint!(x.data, n; kwargs...)
132+
x
133+
end

src/types.jl

Lines changed: 20 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ const NO_METADATA = nothing
1616
sdict(kv...) = Dict{Any, Any}(kv...)
1717

1818
using Base: RefValue
19-
const EMPTY_ARGS = []
19+
const SmallV{T} = SmallVec{T, Vector{T}}
20+
const EMPTY_ARGS = SmallV{Any}()
2021
const EMPTY_HASH = RefValue(UInt(0))
2122
const EMPTY_DICT = sdict()
2223
const EMPTY_DICT_T = typeof(EMPTY_DICT)
@@ -31,7 +32,7 @@ const ENABLE_HASHCONSING = Ref(true)
3132
end
3233
struct Term{T} <: BasicSymbolic{T}
3334
f::Any = identity # base/num if Pow; issorted if Add/Dict
34-
arguments::Vector{Any} = EMPTY_ARGS
35+
arguments::SmallV{Any} = EMPTY_ARGS
3536
hash::RefValue{UInt} = EMPTY_HASH
3637
hash2::RefValue{UInt} = EMPTY_HASH
3738
end
@@ -40,25 +41,25 @@ const ENABLE_HASHCONSING = Ref(true)
4041
dict::EMPTY_DICT_T = EMPTY_DICT
4142
hash::RefValue{UInt} = EMPTY_HASH
4243
hash2::RefValue{UInt} = EMPTY_HASH
43-
arguments::Vector{Any} = EMPTY_ARGS
44+
arguments::SmallV{Any} = EMPTY_ARGS
4445
end
4546
struct Add{T} <: BasicSymbolic{T}
4647
coeff::Any = 0 # exp/den if Pow
4748
dict::EMPTY_DICT_T = EMPTY_DICT
4849
hash::RefValue{UInt} = EMPTY_HASH
4950
hash2::RefValue{UInt} = EMPTY_HASH
50-
arguments::Vector{Any} = EMPTY_ARGS
51+
arguments::SmallV{Any} = EMPTY_ARGS
5152
end
5253
struct Div{T} <: BasicSymbolic{T}
5354
num::Any = 1
5455
den::Any = 1
5556
simplified::Bool = false
56-
arguments::Vector{Any} = EMPTY_ARGS
57+
arguments::SmallV{Any} = EMPTY_ARGS
5758
end
5859
struct Pow{T} <: BasicSymbolic{T}
5960
base::Any = 1
6061
exp::Any = 1
61-
arguments::Vector{Any} = EMPTY_ARGS
62+
arguments::SmallV{Any} = EMPTY_ARGS
6263
end
6364
end
6465

@@ -564,17 +565,8 @@ function Sym{T}(name::Symbol; kw...) where {T}
564565
BasicSymbolic(s)
565566
end
566567

567-
function unwrap_arr!(arr)
568-
for i in eachindex(arr)
569-
arr[i] = unwrap(arr[i])
570-
end
571-
end
572-
573568
function Term{T}(f, args; kw...) where T
574-
if eltype(args) !== Any
575-
args = convert(Vector{Any}, args)
576-
end
577-
unwrap_arr!(args)
569+
args = SmallV{Any}(args)
578570

579571
s = Term{T}(;f=f, arguments=args, hash=Ref(UInt(0)), hash2=Ref(UInt(0)), kw...)
580572
BasicSymbolic(s)
@@ -606,7 +598,7 @@ function Add(::Type{T}, coeff, dict; metadata=NO_METADATA, kw...) where T
606598
end
607599
end
608600

609-
s = Add{T}(; coeff, dict, hash=Ref(UInt(0)), hash2=Ref(UInt(0)), metadata, arguments=[], kw...)
601+
s = Add{T}(; coeff, dict, hash=Ref(UInt(0)), hash2=Ref(UInt(0)), metadata, arguments=SmallV{Any}(), kw...)
610602
BasicSymbolic(s)
611603
end
612604

@@ -624,7 +616,7 @@ function Mul(T, a, b; metadata=NO_METADATA, kw...)
624616
else
625617
coeff = a
626618
dict = b
627-
s = Mul{T}(; coeff, dict, hash=Ref(UInt(0)), hash2=Ref(UInt(0)), metadata, arguments=[], kw...)
619+
s = Mul{T}(; coeff, dict, hash=Ref(UInt(0)), hash2=Ref(UInt(0)), metadata, arguments=SmallV{Any}(), kw...)
628620
BasicSymbolic(s)
629621
end
630622
end
@@ -645,7 +637,7 @@ ratio(x::Rat,y::Rat) = x//y
645637
function maybe_intcoeff(x)
646638
if ismul(x)
647639
if x.coeff isa Rational && isone(x.coeff.den)
648-
Mul{symtype(x)}(; coeff=x.coeff.num, dict=x.dict, x.metadata, arguments=[])
640+
Mul{symtype(x)}(; coeff=x.coeff.num, dict=x.dict, x.metadata, arguments=SmallV{Any}())
649641
else
650642
x
651643
end
@@ -692,7 +684,7 @@ function Div{T}(n, d, simplified=false; metadata=nothing, kwargs...) where {T}
692684
end
693685
end
694686

695-
s = Div{T}(; num=n, den=d, simplified, arguments=[], metadata)
687+
s = Div{T}(; num=n, den=d, simplified, arguments=SmallV{Any}(), metadata)
696688
BasicSymbolic(s)
697689
end
698690

@@ -712,7 +704,7 @@ function Pow{T}(a, b; metadata=NO_METADATA, kwargs...) where {T}
712704
b = unwrap(b)
713705
_iszero(b) && return 1
714706
_isone(b) && return a
715-
s = Pow{T}(; base=a, exp=b, arguments=[], metadata)
707+
s = Pow{T}(; base=a, exp=b, arguments=SmallV{Any}(), metadata)
716708
BasicSymbolic(s)
717709
end
718710

@@ -728,13 +720,13 @@ function toterm(t::BasicSymbolic{T}) where T
728720
args = Any[]
729721
push!(args, t.coeff)
730722
for (k, coeff) in t.dict
731-
push!(args, coeff == 1 ? k : Term{T}(E === MUL ? (^) : (*), Any[coeff, k]))
723+
push!(args, coeff == 1 ? k : Term{T}(E === MUL ? (^) : (*), SmallV{Any}((coeff, k))))
732724
end
733725
Term{T}(operation(t), args)
734726
elseif E === DIV
735-
Term{T}(/, Any[t.num, t.den])
727+
Term{T}(/, SmallV{Any}((t.num, t.den)))
736728
elseif E === POW
737-
Term{T}(^, [t.base, t.exp])
729+
Term{T}(^, SmallV{Any}((t.base, t.exp)))
738730
else
739731
error_on_type()
740732
end
@@ -808,13 +800,13 @@ function makepow(a, b)
808800
end
809801

810802
function term(f, args...; type = nothing)
811-
args = map(unwrap, args)
803+
args = SmallV{Any}(args)
812804
if type === nothing
813805
T = _promote_symtype(f, args)
814806
else
815807
T = type
816808
end
817-
Term{T}(f, Any[args...])
809+
Term{T}(f, args)
818810
end
819811

820812
"""
@@ -826,7 +818,7 @@ function unflatten(t::Symbolic{T}) where{T}
826818
f = operation(t)
827819
if f == (+) || f == (*) # TODO check out for other n-ary --> binary ops
828820
a = arguments(t)
829-
return foldl((x,y) -> Term{T}(f, Any[x, y]), a)
821+
return foldl((x,y) -> Term{T}(f, SmallV{Any}((x, y))), a)
830822
end
831823
end
832824
return t
@@ -1192,7 +1184,7 @@ promote_symtype(f, Ts...) = Any
11921184

11931185
struct FnType{X<:Tuple,Y,Z} end
11941186

1195-
(f::Symbolic{<:FnType})(args...) = Term{promote_symtype(f, symtype.(args)...)}(f, Any[args...])
1187+
(f::Symbolic{<:FnType})(args...) = Term{promote_symtype(f, symtype.(args)...)}(f, SmallV{Any}(args))
11961188

11971189
function (f::Symbolic)(args...)
11981190
error("Sym $f is not callable. " *

0 commit comments

Comments
 (0)