Skip to content

Commit a0908f7

Browse files
committed
changed prefix sum to use an internal cache so it updates a list all at once using set_multiple
1 parent 6686b36 commit a0908f7

File tree

3 files changed

+31
-16
lines changed

3 files changed

+31
-16
lines changed

src/prefixsearch/binarytreeprefixsearch.jl

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ mutable struct BinaryTreePrefixSearch{T<:Real}
1919
offset::Int64 # 2^(depth - 1). Index of first leaf and number of leaves.
2020
cnt::Int64 # Number of leaves in use. Logical number of entries. cnt > 0.
2121
initial_allocation::Int64
22+
cache::Dict{Int64,T}
23+
cached_cnt::Int64
2224
end
2325

2426

@@ -31,7 +33,7 @@ The optional hint, N, is the number of values to pre-allocate.
3133
function BinaryTreePrefixSearch{T}(N=32) where {T<:Real}
3234
depth, offset, array_cnt = _btps_sizes(N)
3335
b = zeros(T, array_cnt)
34-
BinaryTreePrefixSearch{T}(b, depth, offset, 0, N)
36+
BinaryTreePrefixSearch{T}(b, depth, offset, 0, N, Dict{Int64,T}(), 0)
3537
end
3638

3739

@@ -42,6 +44,8 @@ function Base.empty!(ps::BinaryTreePrefixSearch)
4244
ps.depth = depth
4345
ps.offset = offset
4446
ps.cnt = 0
47+
empty!(ps.cache)
48+
ps.cached_cnt = 0
4549
end
4650

4751
function Base.copy!(dst::BinaryTreePrefixSearch{T}, src::BinaryTreePrefixSearch{T}) where {T}
@@ -50,6 +54,8 @@ function Base.copy!(dst::BinaryTreePrefixSearch{T}, src::BinaryTreePrefixSearch{
5054
dst.offset = src.offset
5155
dst.cnt = src.cnt
5256
dst.initial_allocation = src.initial_allocation
57+
copy!(dst.cache, src.cache)
58+
dst.cached_cnt = src.cached_cnt
5359
end
5460

5561

@@ -93,7 +99,7 @@ function Base.resize!(pst::BinaryTreePrefixSearch{T}, newcnt) where {T}
9399
end
94100

95101

96-
Base.length(ps::BinaryTreePrefixSearch) = ps.cnt
102+
Base.length(ps::BinaryTreePrefixSearch) = ps.cnt + ps.cached_cnt
97103
allocated(ps::BinaryTreePrefixSearch) = ps.offset
98104

99105

@@ -125,7 +131,15 @@ function choose(pst::BinaryTreePrefixSearch{T}, value) where {T}
125131
end
126132

127133

128-
Base.sum!(pst::BinaryTreePrefixSearch) = pst.array[1]
134+
# You have to call sum! before calling getindex, or it won't be updated.
135+
function Base.sum!(pst::BinaryTreePrefixSearch)
136+
if !isempty(pst.cache)
137+
set_multiple!(pst, pairs(pst.cache))
138+
empty!(pst.cache)
139+
pst.cached_cnt = 0
140+
end
141+
pst.array[1]
142+
end
129143

130144

131145
"""
@@ -161,16 +175,19 @@ end
161175

162176

163177
function Base.push!(pst::BinaryTreePrefixSearch{T}, value::T) where T
164-
set_multiple!(pst, [(pst.cnt + 1, value)])
178+
pst.cached_cnt += 1
179+
pst.cache[pst.cnt + pst.cached_cnt] = value
165180
return value
166181
end
167182

168183

169184
"""
170185
setindex!(A, X, inds...)
186+
187+
You have to call sum! before calling getindex.
171188
"""
172189
function Base.setindex!(pst::BinaryTreePrefixSearch{T}, value::T, index) where T
173-
set_multiple!(pst, [(index, value)])
190+
pst.cache[index] = value
174191
end
175192

176193
function Base.getindex(pst::BinaryTreePrefixSearch{T}, index) where {T}
@@ -200,8 +217,8 @@ Random.rand(rng::AbstractRNG, d::Random.SamplerTrivial{BinaryTreePrefixSearch{T}
200217

201218
Base.haskey(md::BinaryTreePrefixSearch, clock) = false
202219

203-
function Base.haskey(md::BinaryTreePrefixSearch{T}, clock::Int) where {T}
204-
if 0 < clock length(md)
220+
function Base.haskey(pst::BinaryTreePrefixSearch{T}, clock::Int) where {T}
221+
if 0 < clock length(pst)
205222
return getindex(pst, clock) > zero(T)
206223
else
207224
return false

test/test_direct.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ end
3939

4040
@test length(sampler) == 5
4141
@test length(keys(sampler)) == 5
42+
next(sampler, 0.0, rng) # tells the sampler to rectify caches
4243
@test sampler[1] == 1 / 7.9
4344

4445
@test haskey(sampler, 1)
@@ -48,6 +49,7 @@ end
4849
disable!(sampler, 1, 0.0)
4950

5051
@test_throws KeyError sampler[1]
52+
next(sampler, 0.0, rng)
5153
@test sampler[2] == 1 / 12.3
5254

5355
end

test/test_prefixsearch.jl

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -73,18 +73,9 @@ end
7373
initial_allocation = 1
7474
t = BinaryTreePrefixSearch{Int64}(initial_allocation)
7575
push!(t, 3)
76-
@test t.cnt == 1
77-
@test length(t) == 1
78-
@test allocated(t) == 1
7976
push!(t, 4)
80-
@test length(t) == 2
81-
@test allocated(t) == 2
8277
push!(t, 2)
83-
@test length(t) == 3
84-
@test allocated(t) == 4
8578
push!(t, 3)
86-
@test length(t) == 4
87-
@test allocated(t) == 4
8879
@test sum!(t) == 3+4+2+3
8980
@test t.array[1] == 3+4+2+3
9081
@test t.array[2] == 3+4
@@ -105,11 +96,13 @@ end
10596
@test choose(t, 3)[1] == 2
10697
@test choose(t, 3.7)[1] == 2
10798
t[1] = 4
99+
sum!(t)
108100
v = [(2, 1), (3.3, 1), (4.1, 2)]
109101
for (guess, result) in v
110102
@test choose(t, guess)[1] == result
111103
end
112104
t[2] = 2
105+
sum!(t)
113106
v = [(2, 1), (3.3, 1), (5.1, 2)]
114107
for (guess, result) in v
115108
@test choose(t, guess)[1] == result
@@ -130,6 +123,7 @@ end
130123
@test choose(t, 5)[1] == 3
131124
t[2] = 2
132125
t[3] = 3
126+
sum!(t)
133127
v = [(2, 1), (3, 2), (4.8, 2), (5.1, 3), (7.9, 3)]
134128
for (guess, result) in v
135129
@test choose(t, guess)[1] == result
@@ -149,6 +143,7 @@ end
149143
@test choose(t, 5.5)[1] == 3
150144
t[2] = 2.5
151145
t[3] = 3.5
146+
sum!(t)
152147
v = [(2, 1),(3.5, 2),(4.8, 2),(6.1, 3),(7.9, 3)]
153148
for (guess, result) in v
154149
@test choose(t, guess)[1] == result
@@ -219,6 +214,7 @@ end
219214
vals[idx] = rand(rng, 1:500)
220215
btps[idx] = vals[idx]
221216
end
217+
sum!(btps)
222218
same = all(btps[i] == vals[i] for i in eachindex(vals))
223219
@test same
224220
if !same

0 commit comments

Comments
 (0)