
Commit 492a4f8

Fix data races in lazy construction and show() (#7)
* Fix data races in lazy construction and show()

  Need to always lock accesses to base_cache :)

  - Forgot to lock during lazy construction in thread_cache(tid)
  - Added a lock+copy to show(), to avoid holding the lock while printing to IO, which may block arbitrarily long

* Perf improvement for Base.show: print to buffer instead of copy

  Add test for show

* Add 3-arg show as well

* Fix typo in setindex: I always mix up the order of those arguments..

* Add test for different KV types
1 parent 38daaea commit 492a4f8
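For readers skimming the diff below, the core idea behind the show() fix as a standalone sketch: format the guarded state into an in-memory buffer while holding the lock, then write that buffer to the real IO only after the lock is released, since the real IO may block arbitrarily long. The names `data`, `data_lock`, and `show_without_blocking` are hypothetical stand-ins, not part of the package.

# A minimal sketch of the "render under the lock, write outside the lock" pattern.
# `data` and `data_lock` are hypothetical stand-ins for the guarded state.
const data = Dict{Int,Int}()
const data_lock = ReentrantLock()

function show_without_blocking(io::IO)
    iobuf = IOBuffer()
    Base.@lock data_lock begin
        # Formatting into an IOBuffer never blocks on a slow peer, so the lock
        # is held only for the O(n) rendering work.
        print(iobuf, data)
    end
    # The real `io` (a socket, a pipe, the REPL, ...) can block for an
    # arbitrarily long time, but the lock has already been released here.
    write(io, take!(iobuf))
    return nothing
end

The committed code below uses seekstart plus read instead of take!; either way, the write to the caller's IO happens with the lock released.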

2 files changed: +58 -4 lines changed

src/MultiThreadedCaches.jl

Lines changed: 38 additions & 4 deletions
@@ -88,7 +88,40 @@ function init_cache!(cache::MultiThreadedCache{K,V}) where {K,V}
 end

 function Base.show(io::IO, cache::MultiThreadedCache{K,V}) where {K,V}
+    # Contention optimization: don't hold the lock while printing, since that could block
+    # for an arbitrarily long time. Instead, print the data to an intermediate buffer first.
+    # Note that this has the same CPU complexity, since printing is already O(n).
+    iobuf = IOBuffer()
+    let io = IOContext(iobuf, io)
+        Base.@lock cache.base_cache_lock begin
+            _oneline_show(io, cache)
+        end
+    end
+    # Now print the data without holding the lock.
+    seekstart(iobuf)
+    write(io, read(iobuf))
+    return nothing
+end
+_oneline_show(io::IO, cache::MultiThreadedCache{K,V}) where {K,V} =
     print(io, "$(MultiThreadedCache{K,V})(", cache.base_cache, ")")
+
+function Base.show(io::IO, mime::MIME"text/plain", cache::MultiThreadedCache{K,V}) where {K,V}
+    # Contention optimization: don't hold the lock while printing. See above for more info.
+    iobuf = IOBuffer()
+    let io = IOContext(iobuf, io)
+        Base.@lock cache.base_cache_lock begin
+            if isempty(cache.base_cache)
+                _oneline_show(io, cache)
+            else
+                print(io, "$(MultiThreadedCache): ")
+                Base.show(io, mime, cache.base_cache)
+            end
+        end
+    end
+    # Now print the data without holding the lock.
+    seekstart(iobuf)
+    write(io, read(iobuf))
+    return nothing
 end

 # Based upon the thread-safe Global RNG implementation in the Random stdlib:
@@ -101,7 +134,9 @@ function _thread_cache(mtcache::MultiThreadedCache, tid)
     else
         # We copy the base cache to all the thread caches, so that any precomputed values
         # can be shared without having to wait for a cache miss.
-        cache = copy(mtcache.base_cache)
+        cache = Base.@lock mtcache.base_cache_lock begin
+            copy(mtcache.base_cache)
+        end
         @inbounds mtcache.thread_caches[tid] = cache
     end
     return cache
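For context on the hunk above: copy(::Dict) iterates the dictionary's internal tables, so copying while another task mutates the Dict is a data race (the table can be rehashed mid-copy), which is why the copy itself must also take the lock. A minimal illustration of the locked-copy pattern, with hypothetical names `base` and `lk` (not the package's fields):

# Hypothetical names; not the package's fields.
base = Dict{Int,Int}()
lk = ReentrantLock()

# Writers always mutate under the lock.
writer = Threads.@spawn for i in 1:1000
    Base.@lock lk begin
        base[i] = i
    end
end

# The snapshot must take the same lock: an unlocked copy(base) could observe
# the table mid-rehash and crash or silently produce a corrupt copy.
snapshot = Base.@lock lk copy(base)

wait(writer)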
@@ -117,7 +152,6 @@ function _thread_lock(cache::MultiThreadedCache, tid)
     end
     return lock
 end
-@noinline _thread_cache_length_assert() = @assert false "** Must call `init_cache!(cache)` in your Module's __init__()! - length(cache.thread_caches) < Threads.nthreads() "


 const CACHE_MISS = :__MultiThreadedCaches_key_not_found__
@@ -151,10 +185,10 @@ function Base.get!(func::Base.Callable, cache::MultiThreadedCache{K,V}, key) where {K,V}
         Base.@lock tlock begin
             if test_haskey
                 if !haskey(tcache, key)
-                    setindex!(tcache, key, v)
+                    tcache[key] = v
                 end
             else
-                setindex!(tcache, key, v)
+                tcache[key] = v
             end
         end
     end
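On the two one-line replacements above: Base.setindex! on a Dict takes the value before the key, so setindex!(tcache, key, v) stored the mapping backwards whenever the key and value types happened to match, and throws when they don't (which is what the new KV-types test catches). The indexing syntax tcache[key] = v lowers to the correct call. A quick reminder of the equivalence:

d = Dict{Int,String}()

d[1] = "one"              # lowers to setindex!(d, "one", 1): value first, then key
setindex!(d, "two", 2)    # the equivalent explicit call

# The swapped form setindex!(d, 2, "two") would treat "two" as the key and 2 as
# the value, and here throws because "two" cannot be converted to an Int key.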

test/MultiThreadedCaches.jl

Lines changed: 20 additions & 0 deletions
@@ -31,6 +31,20 @@ using Test
     # Internals test:
    @test length(cache.base_cache_futures) == 0
 end
+@testset "KV types" begin
+    cache = MultiThreadedCache{Int,String}()
+    init_cache!(cache)
+
+    @test get!(()->"hi", cache, 1) == "hi"
+    @test get!(()->"bye", cache, 1) == "hi"
+
+    cache = MultiThreadedCache{Any,Any}()
+    init_cache!(cache)
+
+    @test get!(()->"hi", cache, 1) == "hi"
+    @test get!(()->2.0, cache, 1) == "hi"
+    @test get!(()->3.0, cache, 1.0) == "hi"
+end

 # Helper function for stress-test: returns true if all elements in iterable `v` are equal.
 function all_equal(v)
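A note on the assertions in the new testset above: get! only runs the fallback on a cache miss, so once key 1 holds "hi", the later calls return the cached value and their fallbacks (()->"bye", ()->2.0) never run; and since isequal(1, 1.0) is true, the key 1.0 hits the same entry. The same semantics on a plain Dict, for reference:

d = Dict{Any,Any}()

get!(() -> "hi", d, 1)    # miss: the fallback runs, stores, and returns "hi"
get!(() -> 2.0, d, 1)     # hit: returns the cached "hi"; the fallback never runs
get!(() -> 3.0, d, 1.0)   # also a hit: isequal(1, 1.0), so 1.0 finds the same entry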
@@ -121,6 +135,12 @@ end
     @test cache.base_cache == Dict(1=>10)
 end

+@testset "show" begin
+    cache = MultiThreadedCache{Int64, Int64}()
+    # Exercise both show functions
+    show(cache)
+    display(cache)
+end
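About the testset above: show(cache) prints via the new 2-arg Base.show(io, cache) method, while display(cache) goes through the display system, which for plain-text output dispatches to the new 3-arg Base.show(io, MIME"text/plain"(), cache) method, so both new methods get exercised. Equivalent explicit calls, as a sketch assuming the package is loaded:

using MultiThreadedCaches

cache = MultiThreadedCache{Int64, Int64}()

show(stdout, cache)                       # 2-arg show: the compact one-line form
show(stdout, MIME("text/plain"), cache)   # 3-arg show: what text display dispatches to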

0 commit comments
