Skip to content

Commit 635080d

Browse files
committed
Don't send views between workers
1 parent 1faf00f commit 635080d

File tree

1 file changed

+16
-19
lines changed

1 file changed

+16
-19
lines changed

src/darray.jl

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -468,7 +468,7 @@ function locate(d::DArray, I::Int...)
468468
end
469469
end
470470

471-
chunk(d::DArray{T,N,A}, i...) where {T,N,A} = remotecall_fetch(localpart, d.pids[i...], d)::A
471+
chunk(d::DArray{T,N,A}, pid::Int) where {T,N,A} = remotecall_fetch(localpart, pid, d)::A
472472

473473
## convenience constructors ##
474474

@@ -568,22 +568,22 @@ DArray{T,N,S}(A::S) where {T,N,S<:AbstractArray} = distribute(convert(AbstractAr
568568

569569
function Array{S,N}(d::DArray{T,N}) where {S,T,N}
570570
a = Array{S}(undef, size(d))
571-
@sync begin
572-
for i = 1:length(d.pids)
573-
@async a[d.indices[i]...] = chunk(d, i)
571+
@sync for (pid, indices) in zip(d.pids, d.indices)
572+
if !any(isempty, indices)
573+
@async a[indices...] = chunk(d, pid)
574574
end
575575
end
576576
return a
577577
end
578578

579579
function Array{S,N}(s::SubDArray{T,N}) where {S,T,N}
580580
I = s.indices
581-
d = s.parent
581+
d = parent(s)
582582
if isa(I,Tuple{Vararg{UnitRange{Int}}}) && S<:T && T<:S && !isempty(s)
583583
l = locate(d, map(first, I)...)
584584
if isequal(d.indices[l...], I)
585585
# SubDArray corresponds to a chunk
586-
return chunk(d, l...)
586+
return chunk(d, d.pids[l...])
587587
end
588588
end
589589
a = Array{S}(undef, size(s))
@@ -697,10 +697,11 @@ end
697697

698698
function Base.setindex!(a::Array, d::DArray,
699699
I::Union{UnitRange{Int},Colon,Vector{Int},StepRange{Int,Int}}...)
700-
n = length(I)
701-
@sync for i = 1:length(d.pids)
702-
K = d.indices[i]
703-
@async a[[I[j][K[j]] for j=1:n]...] = chunk(d, i)
700+
@sync for (pid, K) in zip(d.pids, d.indices)
701+
idxs = map((Ij, Kj) -> Ij[Kj], I, K)
702+
if !any(isempty, idxs)
703+
@async a[idxs...] = chunk(d, pid)
704+
end
704705
end
705706
return a
706707
end
@@ -809,24 +810,20 @@ function Base.setindex!(a::Array, s::SubDArray,
809810
I::Union{UnitRange{Int},Colon,Vector{Int},StepRange{Int,Int}}...)
810811
Inew = Base.to_indices(a, I)
811812
Base.setindex_shape_check(s, Base.index_lengths(Inew...)...)
812-
n = length(Inew)
813-
d = s.parent
813+
d = parent(s)
814814
J = Base.to_indices(d, s.indices)
815-
@sync for i = 1:length(d.pids)
816-
K_c = d.indices[i]
815+
@sync for (pid, K_c) in zip(d.pids, d.indices)
817816
K = map(intersect, J, K_c)
818817
if !any(isempty, K)
819818
K_mask = map(indexin_mask, J, K_c)
820819
idxs = restrict_indices(Inew, K_mask)
821820
if isequal(K, K_c)
822821
# whole chunk
823-
@async a[idxs...] = chunk(d, i)
822+
@async a[idxs...] = chunk(d, pid)
824823
else
825824
# partial chunk
826-
@async a[idxs...] =
827-
remotecall_fetch(d.pids[i]) do
828-
view(localpart(d), [K[j].-first(K_c[j]).+1 for j=1:length(J)]...)
829-
end
825+
localidxs = map((Kj, K_cj) -> Kj .- (first(K_cj) - 1), K, K_c)
826+
@async a[idxs...] = remotecall_fetch((d, idxs) -> localpart(d)[idxs...], pid, d, localidxs)
830827
end
831828
end
832829
end

0 commit comments

Comments
 (0)