Skip to content

Commit b2cff52

Browse files
committed
Make DistributedArrays thread-safe
1 parent 1faf00f commit b2cff52

File tree

6 files changed

+176
-138
lines changed

6 files changed

+176
-138
lines changed

src/core.jl

Lines changed: 82 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,58 +1,109 @@
1-
const registry=Dict{Tuple, Any}()
2-
const refs=Set() # Collection of darray identities created on this node
1+
# Thread-safe registry of DArray references
2+
struct DArrayRegistry
3+
data::Dict{Tuple{Int,Int}, Any}
4+
lock::ReentrantLock
5+
DArrayRegistry() = new(Dict{Tuple{Int,Int}, Any}(), ReentrantLock())
6+
end
7+
const REGISTRY = DArrayRegistry()
8+
9+
function Base.get(r::DArrayRegistry, id::Tuple{Int,Int}, default)
10+
@lock r.lock begin
11+
return get(r.data, id, default)
12+
end
13+
end
14+
function Base.getindex(r::DArrayRegistry, id::Tuple{Int,Int})
15+
@lock r.lock begin
16+
return r.data[id]
17+
end
18+
end
19+
function Base.setindex!(r::DArrayRegistry, val, id::Tuple{Int,Int})
20+
@lock r.lock begin
21+
r.data[id] = val
22+
end
23+
return r
24+
end
25+
function Base.delete!(r::DArrayRegistry, id::Tuple{Int,Int})
26+
@lock r.lock delete!(r.data, id)
27+
return r
28+
end
29+
30+
# Thread-safe set of IDs of DArrays created on this node
31+
struct DArrayRefs
32+
data::Set{Tuple{Int,Int}}
33+
lock::ReentrantLock
34+
DArrayRefs() = new(Set{Tuple{Int,Int}}(), ReentrantLock())
35+
end
36+
const REFS = DArrayRefs()
337

4-
let DID::Int = 1
5-
global next_did
6-
next_did() = (id = DID; DID += 1; (myid(), id))
38+
function Base.push!(r::DArrayRefs, id::Tuple{Int,Int})
39+
# Ensure id refers to a DArray created on this node
40+
if first(id) != myid()
41+
throw(
42+
ArgumentError(
43+
lazy"`DArray` is not created on the current worker: Only `DArray`s created on worker $(myid()) can be stored in this set but the `DArray` was created on worker $(first(id))."))
44+
end
45+
@lock r.lock begin
46+
return push!(r.data, id)
47+
end
48+
end
49+
function Base.delete!(r::DArrayRefs, id::Tuple{Int,Int})
50+
@lock r.lock delete!(r.data, id)
51+
return r
752
end
853

54+
# Global counter to generate a unique ID for each DArray
55+
const DID = Threads.Atomic{Int}(1)
56+
957
"""
1058
next_did()
1159
12-
Produces an incrementing ID that will be used for DArrays.
13-
"""
14-
next_did
60+
Increment a global counter and return a tuple of the current worker ID and the value the
61+
counter held before the increment.
1562
16-
release_localpart(id::Tuple) = (delete!(registry, id); nothing)
17-
release_localpart(d) = release_localpart(d.id)
63+
This tuple is used as a unique ID for a new `DArray`.
64+
"""
65+
next_did() = (myid(), Threads.atomic_add!(DID, 1))
1866

19-
function close_by_id(id, pids)
20-
# @async println("Finalizer for : ", id)
21-
global refs
67+
release_localpart(id::Tuple{Int,Int}) = (delete!(REGISTRY, id); nothing)
68+
function release_allparts(id::Tuple{Int,Int}, pids::Array{Int})
2269
@sync begin
70+
released_myid = false
2371
for p in pids
24-
@async remotecall_fetch(release_localpart, p, id)
72+
if p == myid()
73+
@async release_localpart(id)
74+
released_myid = true
75+
else
76+
@async remotecall_fetch(release_localpart, p, id)
77+
end
2578
end
26-
if !(myid() in pids)
27-
release_localpart(id)
79+
if !released_myid
80+
@async release_localpart(id)
2881
end
2982
end
30-
delete!(refs, id)
31-
nothing
83+
return nothing
3284
end
3385

34-
function Base.close(d::DArray)
35-
# @async println("close : ", d.id, ", object_id : ", object_id(d), ", myid : ", myid() )
36-
if (myid() == d.id[1]) && d.release
37-
@async close_by_id(d.id, d.pids)
38-
d.release = false
39-
end
86+
function close_by_id(id::Tuple{Int,Int}, pids::Array{Int})
87+
release_allparts(id, pids)
88+
delete!(REFS, id)
4089
nothing
4190
end
4291

4392
function d_closeall()
44-
crefs = copy(refs)
45-
for id in crefs
46-
if id[1] == myid() # sanity check
47-
if haskey(registry, id)
48-
d = d_from_weakref_or_d(id)
49-
(d === nothing) || close(d)
93+
@lock REFS.lock begin
94+
while !isempty(REFS.data)
95+
id = pop!(REFS.data)
96+
d = d_from_weakref_or_d(id)
97+
if d isa DArray
98+
finalize(d)
5099
end
51-
yield()
52100
end
53101
end
102+
return nothing
54103
end
55104

105+
Base.close(d::DArray) = finalize(d)
106+
56107
"""
57108
procs(d::DArray)
58109
@@ -67,4 +118,3 @@ Distributed.procs(d::SubDArray) = procs(parent(d))
67118
The identity when input is not distributed
68119
"""
69120
localpart(A) = A
70-

src/darray.jl

Lines changed: 32 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -23,44 +23,40 @@ dfill(v, args...) = DArray(I->fill(v, map(length,I)), args...)
2323
```
2424
"""
2525
mutable struct DArray{T,N,A} <: AbstractArray{T,N}
26-
id::Tuple
26+
id::Tuple{Int,Int}
2727
dims::NTuple{N,Int}
2828
pids::Array{Int,N} # pids[i]==p ⇒ processor p has piece i
2929
indices::Array{NTuple{N,UnitRange{Int}},N} # indices held by piece i
3030
cuts::Vector{Vector{Int}} # cuts[d][i] = first index of chunk i in dimension d
3131
localpart::Union{A,Nothing}
32-
release::Bool
3332

34-
function DArray{T,N,A}(id, dims, pids, indices, cuts, lp) where {T,N,A}
33+
function DArray{T,N,A}(id::Tuple{Int,Int}, dims::NTuple{N,Int}, pids, indices, cuts, lp) where {T,N,A}
3534
# check invariants
3635
if dims != map(last, last(indices))
3736
throw(ArgumentError("dimension of DArray (dim) and indices do not match"))
3837
end
39-
release = (myid() == id[1])
4038

4139
d = d_from_weakref_or_d(id)
4240
if d === nothing
43-
d = new(id, dims, pids, indices, cuts, lp, release)
41+
d = new(id, dims, pids, indices, cuts, lp)
4442
end
4543

46-
if release
47-
push!(refs, id)
48-
registry[id] = WeakRef(d)
49-
50-
# println("Installing finalizer for : ", d.id, ", : ", object_id(d), ", isbits: ", isbits(d))
51-
finalizer(close, d)
44+
if first(id) == myid()
45+
push!(REFS, id)
46+
REGISTRY[id] = WeakRef(d)
47+
finalizer(d) do d
48+
@async close_by_id(d.id, d.pids)
49+
end
5250
end
5351
d
5452
end
5553

5654
DArray{T,N,A}() where {T,N,A} = new()
5755
end
5856

59-
function d_from_weakref_or_d(id)
60-
d = get(registry, id, nothing)
61-
isa(d, WeakRef) && return d.value
62-
return d
63-
end
57+
unpack_weakref(x) = x
58+
unpack_weakref(x::WeakRef) = x.value
59+
d_from_weakref_or_d(id::Tuple{Int,Int}) = unpack_weakref(get(REGISTRY, id, nothing))
6460

6561
Base.eltype(::Type{DArray{T}}) where {T} = T
6662
empty_localpart(T,N,A) = A(Array{T}(undef, ntuple(zero, N)))
@@ -77,41 +73,34 @@ Base.hash(d::DArray, h::UInt) = Base.hash(d.id, h)
7773

7874
## core constructors ##
7975

80-
function DArray(id, init, dims, pids, idxs, cuts)
76+
function DArray(id::Tuple{Int,Int}, init::I, dims, pids, idxs, cuts) where {I}
8177
localtypes = Vector{DataType}(undef,length(pids))
82-
83-
@sync begin
84-
for i = 1:length(pids)
85-
@async begin
86-
local typA
87-
if isa(init, Function)
88-
typA = remotecall_fetch(construct_localparts, pids[i], init, id, dims, pids, idxs, cuts)
89-
else
90-
# constructing from an array of remote refs.
91-
typA = remotecall_fetch(construct_localparts, pids[i], init[i], id, dims, pids, idxs, cuts)
92-
end
93-
localtypes[i] = typA
94-
end
78+
if init isa Function
79+
asyncmap!(localtypes, pids) do pid
80+
return remotecall_fetch(construct_localparts, pid, init, id, dims, pids, idxs, cuts)
81+
end
82+
else
83+
asyncmap!(localtypes, pids, init) do pid, pid_init
84+
# constructing from an array of remote refs.
85+
return remotecall_fetch(construct_localparts, pid, pid_init, id, dims, pids, idxs, cuts)
9586
end
9687
end
9788

98-
if length(unique(localtypes)) != 1
89+
if !allequal(localtypes)
9990
@sync for p in pids
10091
@async remotecall_fetch(release_localpart, p, id)
10192
end
102-
throw(ErrorException("Constructed localparts have different `eltype`: $(localtypes)"))
93+
throw(ErrorException(lazy"Constructed localparts have different `eltype`: $(localtypes)"))
10394
end
10495
A = first(localtypes)
10596

10697
if myid() in pids
107-
d = registry[id]
108-
d = isa(d, WeakRef) ? d.value : d
98+
return unpack_weakref(REGISTRY[id])
10999
else
110100
T = eltype(A)
111101
N = length(dims)
112-
d = DArray{T,N,A}(id, dims, pids, idxs, cuts, empty_localpart(T,N,A))
102+
return DArray{T,N,A}(id, dims, pids, idxs, cuts, empty_localpart(T,N,A))
113103
end
114-
d
115104
end
116105

117106
function construct_localparts(init, id, dims, pids, idxs, cuts; T=nothing, A=nothing)
@@ -124,7 +113,7 @@ function construct_localparts(init, id, dims, pids, idxs, cuts; T=nothing, A=not
124113
end
125114
N = length(dims)
126115
d = DArray{T,N,A}(id, dims, pids, idxs, cuts, localpart)
127-
registry[id] = d
116+
REGISTRY[id] = d
128117
A
129118
end
130119

@@ -152,12 +141,10 @@ function ddata(;T::Type=Any, init::Function=I->nothing, pids=workers(), data::Ve
152141
end
153142

154143
if myid() in pids
155-
d = registry[id]
156-
d = isa(d, WeakRef) ? d.value : d
144+
return unpack_weakref(REGISTRY[id])
157145
else
158-
d = DArray{T,1,T}(id, (npids,), pids, idxs, cuts, nothing)
146+
return DArray{T,1,T}(id, (npids,), pids, idxs, cuts, nothing)
159147
end
160-
d
161148
end
162149

163150
function gather(d::DArray{T,1,T}) where T
@@ -428,7 +415,7 @@ end
428415
function Base.:(==)(d::SubDArray, a::AbstractArray)
429416
cd = copy(d)
430417
t = cd == a
431-
close(cd)
418+
finalize(cd)
432419
return t
433420
end
434421
Base.:(==)(a::AbstractArray, d::DArray) = d == a
@@ -437,19 +424,19 @@ Base.:(==)(d1::DArray, d2::DArray) = invoke(==, Tuple{DArray, AbstractArray}, d1
437424
function Base.:(==)(d1::SubDArray, d2::DArray)
438425
cd1 = copy(d1)
439426
t = cd1 == d2
440-
close(cd1)
427+
finalize(cd1)
441428
return t
442429
end
443430
function Base.:(==)(d1::DArray, d2::SubDArray)
444431
cd2 = copy(d2)
445432
t = d1 == cd2
446-
close(cd2)
433+
finalize(cd2)
447434
return t
448435
end
449436
function Base.:(==)(d1::SubDArray, d2::SubDArray)
450437
cd1 = copy(d1)
451438
t = cd1 == d2
452-
close(cd1)
439+
finalize(cd1)
453440
return t
454441
end
455442

@@ -845,4 +832,3 @@ function Random.rand!(A::DArray, ::Type{T}) where T
845832
remotecall_wait((A, T)->rand!(localpart(A), T), p, A, T)
846833
end
847834
end
848-

0 commit comments

Comments
 (0)