diff --git a/src/darray.jl b/src/darray.jl index d3868ad..a2a308e 100644 --- a/src/darray.jl +++ b/src/darray.jl @@ -468,7 +468,7 @@ function locate(d::DArray, I::Int...) end end -chunk(d::DArray{T,N,A}, i...) where {T,N,A} = remotecall_fetch(localpart, d.pids[i...], d)::A +chunk(d::DArray{T,N,A}, pid::Int) where {T,N,A} = remotecall_fetch(localpart, pid, d)::A ## convenience constructors ## @@ -568,9 +568,9 @@ DArray{T,N,S}(A::S) where {T,N,S<:AbstractArray} = distribute(convert(AbstractAr function Array{S,N}(d::DArray{T,N}) where {S,T,N} a = Array{S}(undef, size(d)) - @sync begin - for i = 1:length(d.pids) - @async a[d.indices[i]...] = chunk(d, i) + @sync for (pid, indices) in zip(d.pids, d.indices) + if !any(isempty, indices) + @async a[indices...] = chunk(d, pid) end end return a @@ -578,12 +578,12 @@ end function Array{S,N}(s::SubDArray{T,N}) where {S,T,N} I = s.indices - d = s.parent + d = parent(s) if isa(I,Tuple{Vararg{UnitRange{Int}}}) && S<:T && T<:S && !isempty(s) l = locate(d, map(first, I)...) if isequal(d.indices[l...], I) # SubDArray corresponds to a chunk - return chunk(d, l...) + return chunk(d, d.pids[l...]) end end a = Array{S}(undef, size(s)) @@ -697,10 +697,11 @@ end function Base.setindex!(a::Array, d::DArray, I::Union{UnitRange{Int},Colon,Vector{Int},StepRange{Int,Int}}...) - n = length(I) - @sync for i = 1:length(d.pids) - K = d.indices[i] - @async a[[I[j][K[j]] for j=1:n]...] = chunk(d, i) + @sync for (pid, K) in zip(d.pids, d.indices) + idxs = map((Ij, Kj) -> Ij[Kj], I, K) + if !any(isempty, idxs) + @async a[idxs...] = chunk(d, pid) + end end return a end @@ -809,24 +810,20 @@ function Base.setindex!(a::Array, s::SubDArray, I::Union{UnitRange{Int},Colon,Vector{Int},StepRange{Int,Int}}...) Inew = Base.to_indices(a, I) Base.setindex_shape_check(s, Base.index_lengths(Inew...)...) - n = length(Inew) - d = s.parent + d = parent(s) J = Base.to_indices(d, s.indices) - @sync for i = 1:length(d.pids) - K_c = d.indices[i] + @sync for (pid, K_c) in zip(d.pids, d.indices) K = map(intersect, J, K_c) if !any(isempty, K) K_mask = map(indexin_mask, J, K_c) idxs = restrict_indices(Inew, K_mask) if isequal(K, K_c) # whole chunk - @async a[idxs...] = chunk(d, i) + @async a[idxs...] = chunk(d, pid) else # partial chunk - @async a[idxs...] = - remotecall_fetch(d.pids[i]) do - view(localpart(d), [K[j].-first(K_c[j]).+1 for j=1:length(J)]...) - end + localidxs = map((Kj, K_cj) -> Kj .- (first(K_cj) - 1), K, K_c) + @async a[idxs...] = remotecall_fetch((d, idxs) -> localpart(d)[idxs...], pid, d, localidxs) end end end